From bb03f6ade4b104c890448bb126cf43c345f26c72 Mon Sep 17 00:00:00 2001 From: bolade Date: Tue, 28 Oct 2025 20:54:15 +0100 Subject: [PATCH] fixed querying --- app/__pycache__/main.cpython-312.pyc | Bin 6698 -> 6643 bytes app/db/__pycache__/models.cpython-312.pyc | Bin 9678 -> 9715 bytes app/main.py | 11 +- .../router_schemas.cpython-312.pyc | Bin 10946 -> 10946 bytes .../__pycache__/querying.cpython-312.pyc | Bin 9236 -> 10699 bytes app/services/company_querying.py | 214 +++++++++++------- app/services/querying.py | 206 +++++++++++------ 7 files changed, 270 insertions(+), 161 deletions(-) diff --git a/app/__pycache__/main.cpython-312.pyc b/app/__pycache__/main.cpython-312.pyc index 26e197f348ebc181e40767c82105890480c71fa1..9932cb90f90797ea03be172b366a12a69d081d4e 100644 GIT binary patch delta 1395 zcmZ`(%}*Og6rb7k7Y5r9{QZF!Y$${j6r4oRl1OPPDy0ook|v6LLC(&AmzXtqYe#Y` z1yyzLMRUn1xrV?Ys-?F7L8Vqgsd|Zg>ZQj>G(@VDUixN@9gLin_BZ=`^WN{heQ$Q3 zojc6=e)svj1h0I88sBXDV#+V-_T>H+>#vi1Mj<@2D|hjt1ld5F<+BR;jIe<^VLm>z ztG0adyj`AdWZ{0lZc;s}74X;stnoq~juorA6 z#j<=X)Rsc#f5%o87Upgi;Q+dmRO^E$#x=^;^Qbd zQg?JDX5J%_=1Cs!sW5>G+A)P$o)teRK?e`7-xd36iXPG7ZW`&P!EPGrrqLF?JjX`w zO%eieMi?#S@;PiyZoSx}wPP&9vU}t0H>IRTYjkR&@yL;(jnl3Kb*3>YF0Q&YniBWi zKhZHU=DA9D#D-_#RucChC1C`i>AF%jSB*8uBEj^;iXzm>Ujw0+gYg&1y9cTg^&hB? zl=lx$tQ9wn9?Lr}_yAXsF)$_m^TZdW9lGUVRN5qr<9#@U@I86yANuavzYeF<`!rrTNu zw=~1lDn+iX8fJy-Ma|-Rr3_lBtQRVr$=#a;Q(J<9sT-DFo>_(osdnN;rnjUqhcJit zi7u_YuS;;IhtyeVlQ1hEY_R7cO>gO}j!WlK@JW8JhhDMSmmal?>aMKmD)K_9QiOSQ zTM*Ct!Zg+RweN|u|6?@3qJ)d$e&B4#o_F770W7|vTMytGZe0>T2VS{8l1>YaXM>YU zTK0QK1^Ekr&n0|;V96RZO@k5fΖeu$+;|8PN>yDnX14BS9?=ZB;jzzITar85b* zjn;3)jcPS3wI>g$42&}sjl$~%!^ti(J$)-nHp}3 zCy~UweE^&8>H|YBgV|J}YMz+dThHP~nN>^t5^Q3e{DE12Q}HUYB(|ez|C^wPdLK=_ XA-Eg_#gEb7)EAix;$iHmLk{GBb9x@u delta 1649 zcmaJ>O-vg{6rNq%YiwgXjsatX%Yp-#VnUMu!Ren+6%q)4ZBkW6YTT`7Y}VMjbassd zhe8yp(o5QOs){PA=&fquQgQ3Cx3*FRq7rCNIpkU?O4U}%C4FPZlmyC3^LgL<-uLai zneqI2;rl-CuU@Yk!Smp!Co|?4i6}0iRV7iwi*n48#4P1xoa{@U>&WC~8 zY^d77x4bx1l|tnEFUpFzkvMUf_;G|>$V#|r#c^QVS1~?|n@KZi;-)wXEB8Xx7`Kof zFpyX52a^7ZWEi)CB=WzKwnLHuXjVVkENRHku_;?Ld^{UGkqw>5)}6@KS6CZ-m(AJg zh4L(hI&o|^ju4GtgoQ$KItlBT%$rB&jnCkA+_BkNT_gc52n#}2eCvTNCM+%XwZ|1B zW2!~cWyn2pdt6fvUJGQN6b{8jVi-ye3cGBUqfA;-R;NY9v}$G)4HpPC7JXU0Fe@u_ zsys^;F3DEy@Z{LF^w88u`r74G=J5RByl;gZ9>EcVY#Zy8WY$|V&c2ZHLI-;yjS5|C z)R}CSfu*e++5jx+4PDEsISS3Am)N?qb*apqg>a%vqA&8oZVm|m%VwC=@~6{^rs<|) zs=8*R)07uenZtoEZRa6}9vDUd;3Kr}3)~+6$L?^rA0dEBHLrOr^d7XresuK=%8-1{ z(UN9}QZQ#|KX8LE3@^Ykx?dao?44~-blV=Kty?Ybb%%$uG|Ax-``Lf4zTBv5PN4=# zXN-9|3e>C2f9j9Z7kSYDfaRp*gAy@J_Cv6=bvy8_knDi;cQ(o|4$3baIs#z8DkE8z zD~sm`30+95+B_(AYS1ZQXSXszMLa*jCESBf0vxN$=*2ljQ;9+O-LuU8uJ04xWZezP z=*x5nGM2NT=W?o+vqasnB&fxTL0@OL8Xn!WJnt)pnyKubWgnRsp0vEzO-hvF)ymni zTsdNnfdQfAZL}orMyO?1wXAMQz!Ghq^3y`!=D>HtvO${mvpb=t6kn!tN9E~RK2^&U zN|>anUNXT(div`u&8r!mYL<0t{n7WX{InoBJrZ?A$}a z9TePm2R==FlGyhQ?fYV@xjkQO#~0i6wQqYn{&Ly{@u3^Jn^r@6uE>rnvNpKuISVYT5pmI{Aq=ZEuO D@vLRA diff --git a/app/db/__pycache__/models.cpython-312.pyc b/app/db/__pycache__/models.cpython-312.pyc index 5bcb5acd548b3e41a1918eab4718d69247394e79..e04c3508d764232f29bac2ad0807225d2c15c6eb 100644 GIT binary patch delta 248 zcmX@-{n?xEG%qg~0}zB9`JW-HxsgvwmPv$hvb}`FWPT~J$qCZ@lhb9rCjaFWspn5& zTEn`U38a{TAxeOWA(bUn5JgM~iW&bHItXgWioD^EFkYUd7e1`=0tg0CPv%M^A!bH ii&1KFps>p1cLH3K=gKKgJ}j5YxN@?lyx-<#c?Tv&o6ScQ1z6<88QCV7 c%#i!S0AzeI5n|MxVX{E(D+7>GBoDM60IUdMK3pG%qg~0}zB9`M;5SyA}XPoCe(h delta 20 acmX>UdMK3pG%qg~0}wzqnZp*ZC2aLt!^PfWQf3UJ_6InbFhy=n+NV3Z;!(2q1;Dq^TMBt=~lgw-&09Z&ei20H=L8|j`<0a=qcXu;5C4x>;B^Iw zTAheX$rDudRV{!oOR4M6>R3H%xMRGnP$iI0YZ*Y0(^`wgDubX+P368U%xzYua)ow9 zYTY7ona3pi_{pKMiSaY1hbDXjCr5n8ho<5UbD}seczSxGpa6U{CUU%+3G^`Yfu0My z<91vM+1YuqdvCNm3_8JM1qD4R0utYOZc#E%GXm#}@gd0=jJO`?4 z5*7t0iM6t);rDs6D>5)N! z#yzogG(yJ&j;DF<^;m%Cux1Kwt96hO8B9P=`scVXT>$_cdtk6v&GSQp)m1xrs3jSd|-O%IMto*o&4#0=C~ z_qgfdD6ffP(L+c_M~)BC?QNh23sIhJYo}*`W}e~20289ajOd>WL}slt4V>GDVmxS5 zJN>VJ{$u*=v7w0}jhQM}T45(swzVq_2f`d5@H2l*ONBJoR0#ZW?SKv%^~#e}SL(&_ zkuha$$qALWj5Gr(`NRsz13m*XBs0OnAiOkDcxb_mmr9-KB*%f&tb3JyRej2JtSAQQ zPDvN)gfz_Wh4XLfXE+cMIsY7VwNLJfs6CC<9=2ErHof4H`ECPcfW7tdw!C%9+aqL0 z(0%>%%L;k)!c(Zpj-n6csHZt`f#V|d%Va`R-_uF189gvIn7>7}?x(E#-82@orloLL zFc+c$mS$t~Ay5~Q6X+Q*b#yQqhzQmLZhAu2eg>u;nGL}~fXz*agX2cgcyj7ktwR?X zTu$R7KL9({ddfp@!z-mMrh0w|pa8V2tJ|yt9+itWBD``s6RWNU(HjAGCbVgP%Bk_j{%uwjNj z&qfy_Zj6#KE{Iz%0kwb$mRx2Y`w)^YDiFZTMJ`DCP$(=F6NWxN18dJ)Fi8nMEPP>r zci2ozR^jy!2B-j)F_;AqJ0oCVM}IfkGZzhWJ<(8Xff1NUIM6c)pn#x|p0iQ@e9wvF zgFQ?{j7EG508eIkIPU>Z&G*3AE(8ES_K?2e406u{dh!AM4TwHpyynyXoDf`#F!&_!_Btf; z;cYOV3RA+9pn^s@mH)o2WP%zX&DCmkconuhYTa!GQU^8onXDf%On-eF~VFZFfHlB>^{jz43A_KL_RPtnYf4_@T_E+<1VrR z&}wGz7;-FsSi72q)QBmh~_ab7|$^ zjYC%trRqD=^}CbxyHoYO>3UDH-jk~DOO^Gf%Z?_?j;6{+*3Bb)8?+mfteWe@`ZZSp zYOoB+TJ3>Vy8^C5P;*A0MtKsNlq}W9LkD%L~>U^Xi|$|NCUx2i-nHqNc496>HS7ig-!Lqs~Q>Ly-j~ih}iO zeeNqZs_BG5rB=U7=7c%etk&ngD)&M3p2e@5MSk5m*wIUtgoPz;_GR6>76V$cE?J&2 z7qi%+a!pv(Uauk6lCXm7UR;N+p+XdF@%cg6e>&C{p;)Y6+Imx1*YRH5@+7PYow{}) z2fTZ>JnEb%Gzs&e;D5=ofLf7=cR6{sBi_4IoHrwvT9f;-C5htA`u+_1(kE84Q{^^` z*dn%!EoO`El;2iPQ@?Hsf*MQ3Zxla!7BbxewSX?_r_ch`iZ$S=Ggm2H74ei3q4-=cFQk6r8?@ z$amq=BAx@>kY(o8h@zF9hH#1? z0BFM?M2Lk%Ne`$SMg@{5eN&u(&48v;kdUo!iJ~8kc*Rr?*o z>eTJBRMo-d;mpomKkd2IliKxKeiW z(%QjKj(jq)*1y*C554y;{;Gfdm2;mTKlk@->wa#1U}n91c0;c#Db1RZ!s|@(%!j^9sKcj=+L{ zYUmnnLZ9xk4c8hz?HjN`=I6CW{BE)wqx3&-KYX-L|Je=-(x0^(jv3HreMU&%qx6uz zXRsfu*7GNU&nBpaHJ(7NK`?;;wJ4!A7x{%FoMLqeim%AUfOD{mrS2HigMAWW)0^|6 za^wRb(BQIaq*PP6q=)#m5de$sM}dSsSO7YU)eIeMS@=(8p zU{oFl$Chp&o6w!_$Qz3WU2>G_9AQ?4XXI3AwNX5MMjCyvPHma{YCJRvXt`snJ2sdjh^T&-9YyLs3+{hYK?QY@f9!S*Ldk=s$~ zc!C9>nWk69Sp;r@aS**vor?l&Edp3nKr@!itO+2k;YX#S2FVhJgc;zN`8MA1@Ibw} zkZ!O>)>4N69$OW7+zPZ-v0H8hyTBSOFgGZ5s}#BK9jgjl0amvGtS(-%!iWp32H(pT z2Y2UMs;R=%Fglwux>6XOEnzFvJM>v@z+l*>$`&BEaj67&7Fb`ttZ=z3j|er3Wm1|a zlTr~oz~o7k7L-Y8p;0{{ld`2!ai7M?T#!>yq1#u!9#*h`)WHd{!CVU5*8;rPM~Vx# zg3^QPC~{xml+ziNf{1LcUk&2WszEru%ee+@WjdN`yJ_Qf{KH{;y z%PxzNxDY&!BS}S3;(`b~j`P4ZGr0ngO!9penS^AV3GlG=_=cW8gI^d?@jgs%6FCOQ zd&ARk`NhA6iz;M6o4haawD=*YjBnfGUls5Z z{|a;34=VN+`r3iYo%acX&s5iBs;d9Fv-f)I%G|Z?Ohe;FvB^=IMJD1qc=|S#Y}oT% z8H1OyG(W7aTRxg8t6eW^d+2CNJGLhs+kf(E%CYA`Rm0`8*I&KUmEPH(+}WRMJd~PnXlla{4E=sq)ST70$~w z_Z^T)xAZ1kdQ*v46mUcdubUyc3udAvC538!BFwLH%bI-GAx|7cC%`|<< zmcEpuU+H;IvIPhBAJ^FI_HXKt-FdnCy5)xLs!ef3ohj#@R7LND3diNj52~^)sKlP_ zK-+h$+vtq5@iFQw8KB-bWa^u5^k40N-wf_$`_A?9)=XvHdgYG$J0~)9`&!3c7M!Tg z>UBqZwiIpeO1Bdld(RHN($H%Y?sePEmp+>MaBA(Pk6-!OD_=SW zpjX=dz70C1?Hg9q{#>S}`O2lsmmY$H{l<^qfX*A+Zk)J!V&(8^>ni)x$gRkFcYm_> z(8Jo68;w^R!RyBI%e1-DZJuPCr$B=4%yc}TY3P9WMs0~a#RgQ-mM-s1mUpgq9saWX z2ida}mhFa@LiJ7f_NA_O`FN&z=jy4ye)DR}+Q9O0;->D*RJVYVXX@HA4ehdeZ}EZ8 zQ6Ri;AC;o!aq8PgHkj1!voBDaoz?@efBx=h`B(|MXYDv~$aHT{JEVV6TREn;{-Ub` z(!X?cjG0Zp?0aE1Wd89`8NC0g#6H%ncTr=m5`3$Kk0PH>vif}RtDIN}r)@q6kuV{- zhQ~ls(!=Q}2bl~c3%?VP&s}_|@e}xkEd|;6w-C8~kdFdSG@Om3P(hXBa zYxbm0XMB;mDP$4;t_qLIH)}hXM-1G@Ex&_*!+oc!Lb2D9E=UNX~IEzUEi+lHG6;h zb}!)5x#viHaHT{*FD+C_T2!?X5lE#$4%DC*ADqyqS6L&HK%~*^kHGe#!T-*Xu&?>pc3!bXUt2U%h2Nb0xgDkmO<^1jnz`#scuXLV?oNh z)VgGStUlQgYeg_Q*u*m6XiW>VRN!M*32Lal9`!4vcRW*@|K&7DT%d! zj&J73)8ph#@~-K$f{uUYzsvz+k?jP?z*grN9KQ%g3+|2AbZk^6fL>R+Ac5>*jsEZS#CXc7n>N!OR zsrv~jyt)oei~2}LPW2BGlLVGUU5*1y=y7pMPU%IRIad55vu}H@RTWcHiD^+ub%@GN z?R#nmRp`)WHC;}2WNH-Lu3 z$_|+$+@xKyUu0uW$pPc6>@vqEZ4!6U7IRZeXI1)t>|J2*nYBbbhVNwYl4YyqB2)a0 zWwf9>!9V@!6k~+VqZzv?dy?L=-?W(Tc@$q8*L_u;nGIEL{V6#l?v}G^C%JMkB)KW| zQ2L6OK75q&l-5wXfl@!E0ZMBr-AHMW(mG1(37uo~Ky_NO{Nk3H0FP!oRT+5}FV%#K zqW;PA-M$kEdRjN}Gf9l|Tqa!SZ*GS&o5oE6%X z)a$aQr?J*1q_NO8siY(!DGTf3_JW{K%d;R9Q^Ju`!-vOmz2K)`Jul8`!uPT=o+S!k znoUXI1putd!X!8|eIx;tLOPR4V?CQvbVb&3FM|CzFOuh~PE6B1lDM^LLIT4nU`34PdR#~7 z!g|PBlVXP^!_pxmB_WkfPLLpz2%I&V*QDJbB>Y;_*h$RyK7HW z7?xj$ontuzC5=W~69I^mqM|}LDb6BR+z8Mocu%XEVK+-y#Dl#=YDt4=svzrf>UF#s zB$%WU-bBV6!~zPc!J9U5sDKc7T{i3{vM?zehI2~RCuL28 zk$5%(MD9)zHrAbW+}#Rs_$D>3M8N-lkhlHHaT#DjPWOP=P7? z5#Ln_IK5rvTEur$HoCkWzU&V#`ol~99V>yzV}~v1TClz4dIV8G zg<*q(!_!Fa8ZLEKgXX88<2D9$Qs5$L5P$LL0~OY~E76UB*y^%YAkNrxSn}L0m1ASxrezQE(3p z9T!AM1~Upz)=D!?Ab&{_DE=z&PQl<-VW$~3Qi2Q~%ZV)3Af$0sNhr_O zz1O=-q4rWkN68aisoi!RmjXM=EUMl12w7`g3mhb&lfUHtf&21>%QJVmmK7m& zzCF4qbT12ii$Y(C+w>XdqdMu*_HUK8?Jw0HxXbn457aFOUMMUEUMO`RUkVH_aFCZZ zwaa|lBHwm{yS-zHKTzTh;OEKG-J-Gkg|yycmg zw&s~Cn&yt%^0ve15}N>%gyl6Tcyp{y(uMx=Y!!-EEN}8Qa0-xowo_Z4nQ;^=-p0b} z^5*P$JFd^$!Mo!L?~W(DuXVg@-W_@C7_5VV{~UY1$~n%Xtw>)>rB;MPbDX}`&`4~a zTTi6|LJT6CIgKRC9QS>0Z1s7bIgc(_Powip#GY#zFwwYvJrghr0PQNiqL^%CIOjM! zaaMZn>_m<~JE2Htdd_OE?FfS;l4~8n81A`MT$2PO*E~h_nwK6SVmB;UHXNh_5>?f( zDv%AVnm!LX;vBe;kZB8Vuo(;k5vPgw0hk&eCFBqg!?Akn8&+Tpmm$OAgaizN53on)G3cEou)bz8^PeP^Ey?5Np z-GhtWgG=4VmKuf@Uj7T;u(G%Blg9oVN8UYj-TvOOQe*#R4hUR0HT4Vpmk;Yu{hlu$ zx=_mMSEGx(WjL#ej=&W1bIl7IhQxY7a}-x~aPnT4URw0*GDg7}yW%{r$UaAMUUaKI$U0+j+3j`Vqq)3|c>`VF({&f#O3ST&jPYb3Y&-aSzDH zL8SQS+Kyt^mY3|`fj1m4&To06K)QZtmgSUTOOXR7d1tV+kYEyc!;z-vsA1Be8+NnD zfi6H%p{2vblfh0Z7|zN8JU&ARK{S@hkVlurB$P7R+u=zf+OXc{#BEI7Musij$6c)?6&Bkwy(m#8fRF*t&b`X<#+7aT9w_QCfR(U-9j& zZA|Ke-*5eewXl=;v_dE_+@c5%5<&t@_kj-+j?^@YUfc)V`-ql7(^ZfR5kfjJ?F}Fo z(L+l!>na~+P)~2gTF3UT)HjytGtzl`>-(E1f8QA@HFhjIqnjVGpi=ItW$W+kDkFIO zVfUY`AMSfZ*mAV^pSC?sr0UX1B4G%$20ey@+-s2Jcw2Eeyrp0wz$r~clW9p-HS<{M ziBdc`*S?Z50hQ6R_8uoKJ^5srjnZC_HipncYP7F3Ha$x*t1Bw~#=+!U&!|y%RKLOF zX*g9;FYcVktou*k=ynwZa)MjtIECjNAOCr zOUnZ(TN#G=6uCb|zE6?!bF}3jX#3}A|03Fd4>jCF{o*~e=N@XihdRHou#Dvq0`gA} zV#0q%C*cGa>Xx0MMQ7+*_mZ=9-hSU7y1bS2i%>z9{0Hax6}Dmi%re`s$TnPSUt-(l REnnFCtxWSnL?~U}{{R$^{2BlN diff --git a/app/services/company_querying.py b/app/services/company_querying.py index 80a2b71..e002af3 100644 --- a/app/services/company_querying.py +++ b/app/services/company_querying.py @@ -1,21 +1,17 @@ +import hashlib import logging import os from typing import List -from db.db import DATABASE_URL, get_db +from db.db import get_db from db.models import CompanyTable -from langchain import hub -from langchain_community.agent_toolkits import SQLDatabaseToolkit -from langchain_community.utilities import SQLDatabase +from langchain_core.prompts import ChatPromptTemplate from langchain_openai import ChatOpenAI -from langgraph.prebuilt import create_react_agent from schemas.router_schemas import CompanyData, PaginatedResponse +from sqlalchemy import text from sqlalchemy.orm import selectinload logger = logging.getLogger(__name__) -# Connect to SQLite -prompt_template = hub.pull("langchain-ai/sql-agent-system-prompt") -db = SQLDatabase.from_uri(DATABASE_URL) class CompanyQueryProcessor: @@ -26,96 +22,144 @@ class CompanyQueryProcessor: model="openai/gpt-4o-mini", temperature=0, ) - self.toolkit = SQLDatabaseToolkit(db=db, llm=self.llm) - # Update system message to specifically request only company IDs - system_message_updated = ( - prompt_template.format(dialect="SQLite", top_k=5) - + "\n\n=== CRITICAL INSTRUCTIONS ===" - + "\n- Your ONLY task is to run SQL queries and extract company IDs" - + "\n- When you get SQL results with company IDs, return them EXACTLY as shown" - + "\n- If the SQL query returns rows with company IDs like [(1,), (5,), (9,)], return all those IDs" - + "\n- Do NOT add any explanations, just list the IDs" - + "\n- If a query returns NO ROWS (empty result), then respond with 'NO_RESULTS'" - + "\n\n=== QUERY GUIDELINES ===" - + "\n1. For sector searches: SELECT companies.id FROM companies JOIN company_sector ON companies.id = company_sector.company_id JOIN sectors ON company_sector.sector_id = sectors.id WHERE sectors.name LIKE '%sector_name%'" - + "\n2. For industry searches: WHERE companies.industry LIKE '%search_term%'" - + "\n3. For location searches: WHERE companies.location LIKE '%location%'" - + "\n4. For founding year searches: WHERE companies.founded_year >= year" - + "\n5. For investor-related: JOIN investor_companies table" - ) - self.agent = create_react_agent( - model=self.llm, - tools=self.toolkit.get_tools(), - prompt=system_message_updated, + + # Query cache for performance + self.query_cache = {} + + # SQL generation prompt + self.sql_prompt = ChatPromptTemplate.from_messages( + [ + ( + "system", + """You are a SQL expert. Generate a SQLite query to find company IDs based on user requirements. + +Database Schema: +- companies: id, name, industry, location, description, founded_year, website +- company_sector: company_id, sector_id +- sectors: id, name +- investor_companies: investor_id, company_id +- investors: id, name, aum +- team_members: id, company_id, name, title + +IMPORTANT RULES: +1. ALWAYS return ONLY company IDs (companies.id) - use SELECT DISTINCT c.id +2. For industry: Check BOTH industry field AND sectors table with synonyms + - Use LEFT JOIN for sectors so companies without sector tags still match + - Include related terms: 'Fintech' → c.industry LIKE '%Fintech%' OR c.industry LIKE '%Finance%' OR sec.name LIKE '%Fintech%' OR sec.name LIKE '%Financial%' + - 'AI' → c.industry LIKE '%AI%' OR c.industry LIKE '%Artificial Intelligence%' OR c.industry LIKE '%Machine Learning%' OR sec.name LIKE '%AI%' OR sec.name LIKE '%ML%' +3. For location: Be FLEXIBLE with variations and abbreviations + - 'San Francisco' → c.location LIKE '%San Francisco%' OR c.location LIKE '%SF%' OR c.location LIKE '%Bay Area%' + - 'New York' → c.location LIKE '%New York%' OR c.location LIKE '%NYC%' OR c.location LIKE '%NY%' + - 'Europe' → c.location LIKE '%Europe%' OR c.location LIKE '%UK%' OR c.location LIKE '%London%' OR c.location LIKE '%Berlin%' OR c.location LIKE '%Paris%' +4. For sectors: Use LEFT JOIN and include multiple synonyms + - 'Healthcare' → sec.name LIKE '%Healthcare%' OR sec.name LIKE '%Health%' OR sec.name LIKE '%Medical%' OR sec.name LIKE '%Biotech%' OR c.industry LIKE '%Health%' +5. For founding year filters (include NULL to be inclusive): + - "founded after 2020" → WHERE (founded_year >= 2020 OR founded_year IS NULL) + - "founded before 2018" → WHERE (founded_year <= 2018 OR founded_year IS NULL) + - "founded in 2020" → WHERE founded_year = 2020 +6. For investor-related queries: Use JOIN investor_companies +7. Use LEFT JOIN for sectors so companies without tags still match +8. Use DISTINCT to avoid duplicates from joins +9. Be INCLUSIVE - use OR conditions with synonyms and variations +10. Return a single, complete SELECT query + +Example Queries: +Q: "Fintech companies founded in 2020" +A: SELECT DISTINCT c.id FROM companies c + LEFT JOIN company_sector cs ON c.id = cs.company_id + LEFT JOIN sectors sec ON cs.sector_id = sec.id + WHERE (c.industry LIKE '%Fintech%' OR c.industry LIKE '%Finance%' OR c.industry LIKE '%Financial%' OR sec.name LIKE '%Fintech%' OR sec.name LIKE '%Financial Services%') + AND c.founded_year = 2020 + +Q: "AI companies in San Francisco" +A: SELECT DISTINCT c.id FROM companies c + LEFT JOIN company_sector cs ON c.id = cs.company_id + LEFT JOIN sectors sec ON cs.sector_id = sec.id + WHERE (c.industry LIKE '%AI%' OR c.industry LIKE '%Artificial Intelligence%' OR c.industry LIKE '%Machine Learning%' OR sec.name LIKE '%AI%' OR sec.name LIKE '%Machine Learning%' OR sec.name LIKE '%ML%') + AND (c.location LIKE '%San Francisco%' OR c.location LIKE '%SF%' OR c.location LIKE '%Bay Area%') + +Q: "Healthcare companies" +A: SELECT DISTINCT c.id FROM companies c + LEFT JOIN company_sector cs ON c.id = cs.company_id + LEFT JOIN sectors sec ON cs.sector_id = sec.id + WHERE c.industry LIKE '%Healthcare%' OR c.industry LIKE '%Health%' OR c.industry LIKE '%Medical%' OR sec.name LIKE '%Healthcare%' OR sec.name LIKE '%Medical%' OR sec.name LIKE '%Biotech%' OR sec.name LIKE '%Pharma%' + +Q: "Companies funded by Sequoia" +A: SELECT DISTINCT c.id FROM companies c + JOIN investor_companies ic ON c.id = ic.company_id + JOIN investors i ON ic.investor_id = i.id + WHERE i.name LIKE '%Sequoia%' + +Q: "European startups founded after 2019" +A: SELECT DISTINCT c.id FROM companies c + WHERE (c.location LIKE '%Europe%' OR c.location LIKE '%UK%' OR c.location LIKE '%London%' OR c.location LIKE '%Germany%' OR c.location LIKE '%Berlin%' OR c.location LIKE '%France%' OR c.location LIKE '%Paris%') + AND (c.founded_year > 2019 OR c.founded_year IS NULL) + +Q: "SaaS companies" +A: SELECT DISTINCT c.id FROM companies c + LEFT JOIN company_sector cs ON c.id = cs.company_id + LEFT JOIN sectors sec ON cs.sector_id = sec.id + WHERE c.industry LIKE '%SaaS%' OR c.industry LIKE '%Software%' OR c.industry LIKE '%Cloud%' OR sec.name LIKE '%SaaS%' OR sec.name LIKE '%Software%' + +IMPORTANT: +- Use LEFT JOIN so companies without sector tags still match via industry field +- Use OR conditions with related keywords/synonyms to cast a wider net +- Include NULL checks for optional filters to avoid excluding companies with missing data + +Return ONLY the SQL query, no explanations or markdown.""", + ), + ("user", "{question}"), + ] ) + def _get_cache_key(self, question: str) -> str: + """Generate cache key from normalized question.""" + return hashlib.md5(question.lower().strip().encode()).hexdigest() + def process_query(self, question: str) -> PaginatedResponse[CompanyData]: - """Process a query using the LLM and return company response data. + """Process a query by generating and executing SQL directly. Args: question: The natural language query to process """ - # Let the LLM handle all database interactions and filtering to get company IDs - response = self.agent.invoke( - {"messages": [("user", question)]}, - config={"recursion_limit": 50}, - ) + cache_key = self._get_cache_key(question) - # Extract the actual message content - logger.info(f"{response}") + # Check cache first + if cache_key in self.query_cache: + sql_query = self.query_cache[cache_key] + logger.info(f"Using cached SQL: {sql_query}") + else: + # Generate SQL query + messages = self.sql_prompt.format_messages(question=question) + response = self.llm.invoke(messages) + sql_query = response.content.strip() - # Look through all messages to find the SQL query results (ToolMessage with actual data) - company_ids = [] - for message in response["messages"]: - if hasattr(message, "content") and message.content: - # Check if this looks like SQL results (contains tuples with numbers) - if "(" in str(message.content) and "," in str(message.content): - company_ids = self._extract_company_ids_from_response( - str(message.content) - ) - if company_ids: - logger.info( - f"Extracted {len(company_ids)} company IDs from results" - ) - break + # Clean up SQL (remove markdown code blocks if present) + sql_query = sql_query.replace("```sql", "").replace("```", "").strip() - # If no IDs found from ToolMessage, check the final AI message - if not company_ids: - final_message_content = response["messages"][-1].content - logger.info(f"AI Response: \n{final_message_content}") - company_ids = self._extract_company_ids_from_response(final_message_content) - - # Fetch full company data with relationships using the IDs - return self._fetch_companies_by_ids(company_ids) - - def _extract_company_ids_from_response(self, ai_response: str) -> List[int]: - """Extract company IDs from AI response.""" - import re - - company_ids = [] - - # Check if response is NO_RESULTS - if "NO_RESULTS" in ai_response.upper(): - return [] + # Cache the query + self.query_cache[cache_key] = sql_query + logger.info(f"Generated SQL: {sql_query}") + # Execute query to get company IDs + db_session = next(get_db()) try: - # The response contains tuples like (1,), (5,), etc. - # Extract numbers between parentheses - pattern = r"\((\d+),?\)" - matches = re.findall(pattern, ai_response) - if matches: - company_ids = [int(match) for match in matches] - else: - # Fallback: extract all numbers - numbers = re.findall(r"\b\d+\b", ai_response) - # Filter out very large numbers that might be tokens or timestamps - company_ids = [int(num) for num in numbers if int(num) < 100000] + result = db_session.execute(text(sql_query)) + company_ids = [row[0] for row in result.fetchall()] + logger.info( + f"Found {len(company_ids)} company IDs: {company_ids[:10]}{'...' if len(company_ids) > 10 else ''}" + ) + return self._fetch_companies_by_ids(company_ids) except Exception as e: - logger.error(f"Error extracting IDs from response: {e}") - return [] - - return company_ids + logger.error(f"SQL execution error: {e}") + logger.error(f"Failed SQL: {sql_query}") + # Return empty result + return PaginatedResponse( + items=[], total=0, page=1, page_size=10, total_pages=0 + ) + finally: + db_session.close() def _fetch_companies_by_ids( self, company_ids: List[int] @@ -130,7 +174,7 @@ class CompanyQueryProcessor: items=[], total=0, page=1, - page_size=len(company_ids) if company_ids else 10, + page_size=10, total_pages=0, ) diff --git a/app/services/querying.py b/app/services/querying.py index 5bd0219..2a566eb 100644 --- a/app/services/querying.py +++ b/app/services/querying.py @@ -1,29 +1,24 @@ -import json +import hashlib import logging import os from typing import List, Optional -from db.db import DATABASE_URL, get_db +from db.db import get_db from db.models import FundTable, InvestorTable, ProjectTable -from langchain import hub -from langchain_community.agent_toolkits import SQLDatabaseToolkit -from langchain_community.utilities import SQLDatabase +from langchain_core.prompts import ChatPromptTemplate from langchain_openai import ChatOpenAI -from langgraph.prebuilt import create_react_agent from schemas.router_schemas import ( CompanyMinimal, InvestmentResponse, PaginatedResponse, SectorMinimal, ) +from sqlalchemy import text from sqlalchemy.orm import selectinload from services.compatibility_score import calculate_project_investor_compatibility logger = logging.getLogger(__name__) -# Connect to SQLite -prompt_template = hub.pull("langchain-ai/sql-agent-system-prompt") -db = SQLDatabase.from_uri(DATABASE_URL) class QueryProcessor: @@ -34,89 +29,150 @@ class QueryProcessor: model="openai/gpt-4o-mini", temperature=0, ) - self.toolkit = SQLDatabaseToolkit(db=db, llm=self.llm) - # Update system message to specifically request only fund IDs - system_message_updated = ( - prompt_template.format(dialect="SQLite", top_k=100) - + "\n\n=== IMPORTANT TERMINOLOGY ===" - + "\n- When users say 'investors' or 'find me investors', they mean FUNDS" - + "\n- Always query the 'funds' table for investment opportunities" - + "\n- The 'investors' table is for parent company information only" - + "\n- Relationship: investors (1) -> (many) funds" - + "\n\n=== YOUR TASK ===" - + "\nReturn ONLY fund IDs (funds.id) that match the user's criteria." - + "\nFormat: comma-separated numbers only (e.g., 1, 5, 12, 23)" - + "\nNo explanations, no other data." - + "\n\n=== QUERY GUIDELINES ===" - + "\n1. For geographic searches: use funds.geographic_focus" - + "\n2. For sector searches: JOIN with fund_sectors table" - + "\n3. For stage searches: JOIN with fund_investment_stages table" - + "\n4. Return ALL matching fund IDs, not just the first few" - + "\n5. If no results: respond with 'NO_RESULTS'" - + "\n6. Never repeat the same failed query" - + "\n\n=== GEOGRAPHIC SEARCH RULES (VERY IMPORTANT) ===" - + "\n- ALWAYS use LIKE '%keyword%' for geographic searches, NEVER use exact equality (=)" - + "\n- When user says 'Europe', match ANY location containing 'Europe' (e.g., 'Northern Europe', 'Western Europe', 'Europe', 'Central Europe')" - + "\n- When user says 'America', match locations like 'North America', 'South America', 'Latin America', 'United States'" - + "\n- When user says 'Asia', match 'Asia', 'Southeast Asia', 'East Asia', etc." - + "\n- Examples:" - + "\n * User: 'Europe' → SQL: WHERE geographic_focus LIKE '%Europe%'" - + "\n * User: 'America' → SQL: WHERE geographic_focus LIKE '%America%'" - + "\n * User: 'UK' → SQL: WHERE geographic_focus LIKE '%UK%' OR geographic_focus LIKE '%United Kingdom%'" - + "\n- Be INCLUSIVE: capture all relevant regional variations" - ) - self.agent = create_react_agent( - model=self.llm, - tools=self.toolkit.get_tools(), - prompt=system_message_updated, + + # Query cache for performance + self.query_cache = {} + + # SQL generation prompt + self.sql_prompt = ChatPromptTemplate.from_messages( + [ + ( + "system", + """You are a SQL expert. Generate a SQLite query to find fund IDs based on user requirements. + +Database Schema: +- funds: id, fund_name, investor_id, check_size_lower, check_size_upper, geographic_focus +- fund_sectors: fund_id, sector_id +- fund_investment_stages: fund_id, stage_id +- sectors: id, name +- investment_stages: id, name +- investors: id, name, aum + +IMPORTANT RULES: +1. ALWAYS return ONLY fund IDs (funds.id) - use SELECT DISTINCT f.id +2. For geography: Be FLEXIBLE - use OR with variations and partial matches + - 'Europe' → WHERE geographic_focus LIKE '%Europe%' OR geographic_focus LIKE '%European%' + - 'America' → WHERE geographic_focus LIKE '%America%' OR geographic_focus LIKE '%US%' OR geographic_focus LIKE '%United States%' + - 'Asia' → WHERE geographic_focus LIKE '%Asia%' OR geographic_focus LIKE '%Asian%' + - If no geography specified, DON'T filter by geography +3. For stages: Use LEFT JOIN and LIKE for flexible matching with synonyms + - 'Seed' → s.name LIKE '%Seed%' OR s.name LIKE '%Pre-Seed%' OR s.name LIKE '%Early%' + - 'Series A' → s.name LIKE '%Series A%' OR s.name LIKE '%A%' + - 'Growth' → s.name LIKE '%Growth%' OR s.name LIKE '%Late%' OR s.name LIKE '%Expansion%' + - If stage not specified, include ALL funds +4. For sectors: Use LEFT JOIN and include related terms with OR + - 'Fintech' → sec.name LIKE '%Fintech%' OR sec.name LIKE '%Finance%' OR sec.name LIKE '%Financial%' + - 'AI' → sec.name LIKE '%AI%' OR sec.name LIKE '%Artificial Intelligence%' OR sec.name LIKE '%Machine Learning%' OR sec.name LIKE '%ML%' + - 'Healthcare' → sec.name LIKE '%Healthcare%' OR sec.name LIKE '%Health%' OR sec.name LIKE '%Medical%' OR sec.name LIKE '%Biotech%' +5. For check size filters (be flexible with ranges): + - "under X" → WHERE (check_size_upper <= X OR check_size_upper IS NULL) + - "over X" → WHERE (check_size_lower >= X OR check_size_lower IS NULL) + - "between X and Y" → WHERE check_size_lower >= X AND check_size_upper <= Y +6. Use LEFT JOIN for stages and sectors so funds without tags still match +7. Use DISTINCT to avoid duplicates from joins +8. Be INCLUSIVE - use OR conditions to cast a wider net +9. If query is very simple (e.g., just "seed stage"), don't add unnecessary filters +10. Return a single, complete SELECT query + +Example Queries: +Q: "Seed stage investors in Europe" +A: SELECT DISTINCT f.id FROM funds f + LEFT JOIN fund_investment_stages fis ON f.id = fis.fund_id + LEFT JOIN investment_stages s ON fis.stage_id = s.id + WHERE (s.name LIKE '%Seed%' OR s.name LIKE '%Pre-Seed%' OR s.name LIKE '%Early%' OR s.id IS NULL) + AND (f.geographic_focus LIKE '%Europe%' OR f.geographic_focus LIKE '%European%') + +Q: "Fintech investors with check size under 5 million" +A: SELECT DISTINCT f.id FROM funds f + LEFT JOIN fund_sectors fs ON f.id = fs.fund_id + LEFT JOIN sectors sec ON fs.sector_id = sec.id + WHERE (sec.name LIKE '%Fintech%' OR sec.name LIKE '%Finance%' OR sec.name LIKE '%Financial%' OR sec.id IS NULL) + AND (f.check_size_upper <= 5000000 OR f.check_size_upper IS NULL) + +Q: "Seed stage investors" +A: SELECT DISTINCT f.id FROM funds f + LEFT JOIN fund_investment_stages fis ON f.id = fis.fund_id + LEFT JOIN investment_stages s ON fis.stage_id = s.id + WHERE s.name LIKE '%Seed%' OR s.name LIKE '%Pre-Seed%' OR s.name LIKE '%Early%' + +Q: "Growth stage investors" +A: SELECT DISTINCT f.id FROM funds f + LEFT JOIN fund_investment_stages fis ON f.id = fis.fund_id + LEFT JOIN investment_stages s ON fis.stage_id = s.id + WHERE s.name LIKE '%Growth%' OR s.name LIKE '%Late%' OR s.name LIKE '%Expansion%' OR s.name LIKE '%Series C%' OR s.name LIKE '%Series D%' + +Q: "AI investors in America" +A: SELECT DISTINCT f.id FROM funds f + LEFT JOIN fund_sectors fs ON f.id = fs.fund_id + LEFT JOIN sectors sec ON fs.sector_id = sec.id + WHERE (sec.name LIKE '%AI%' OR sec.name LIKE '%Artificial Intelligence%' OR sec.name LIKE '%Machine Learning%' OR sec.name LIKE '%ML%') + AND (f.geographic_focus LIKE '%America%' OR f.geographic_focus LIKE '%US%' OR f.geographic_focus LIKE '%United States%' OR f.geographic_focus LIKE '%USA%') + +Q: "Healthcare investors" +A: SELECT DISTINCT f.id FROM funds f + LEFT JOIN fund_sectors fs ON f.id = fs.fund_id + LEFT JOIN sectors sec ON fs.sector_id = sec.id + WHERE sec.name LIKE '%Healthcare%' OR sec.name LIKE '%Health%' OR sec.name LIKE '%Medical%' OR sec.name LIKE '%Biotech%' OR sec.name LIKE '%Pharma%' + +IMPORTANT: Use LEFT JOIN so funds without sector/stage tags can still match. Include synonym terms with OR for better recall. + +Return ONLY the SQL query, no explanations or markdown.""", + ), + ("user", "{question}"), + ] ) + def _get_cache_key(self, question: str) -> str: + """Generate cache key from normalized question.""" + return hashlib.md5(question.lower().strip().encode()).hexdigest() + def process_query( self, question: str, project_id: Optional[int] = None ) -> PaginatedResponse[InvestmentResponse]: - """Process a query using the LLM and return investment response data. + """Process a query by generating and executing SQL directly. Args: question: The natural language query to process project_id: Optional project ID for compatibility scoring """ - # Let the LLM handle all database interactions and filtering to get fund IDs - response = self.agent.invoke( - {"messages": [("user", question)]}, - config={"recursion_limit": 50}, - ) + cache_key = self._get_cache_key(question) - # Extract the actual message content - logger.info(f"{response}") - final_message_content = response["messages"][-1].content - logger.info(f"AI Response: \n{final_message_content}") - # Extract fund IDs from the AI response - fund_ids = self._extract_fund_ids_from_response(final_message_content) + # Check cache first + if cache_key in self.query_cache: + sql_query = self.query_cache[cache_key] + logger.info(f"Using cached SQL: {sql_query}") + else: + # Generate SQL query + messages = self.sql_prompt.format_messages(question=question) + response = self.llm.invoke(messages) + sql_query = response.content.strip() - # Fetch full fund data with investor relationships using the IDs - return self._fetch_funds_by_ids(fund_ids, project_id) + # Clean up SQL (remove markdown code blocks if present) + sql_query = sql_query.replace("```sql", "").replace("```", "").strip() - def _extract_fund_ids_from_response(self, ai_response: str) -> List[int]: - """Extract fund IDs from AI response.""" - import re + # Cache the query + self.query_cache[cache_key] = sql_query + logger.info(f"Generated SQL: {sql_query}") - fund_ids = [] + # Execute query to get fund IDs + db_session = next(get_db()) try: - # Try multiple patterns to extract IDs from the response - # Pattern 1: Simple numbers (assuming they are IDs) - numbers = re.findall(r"\b\d+\b", ai_response) - fund_ids = [int(num) for num in numbers] - - # Pattern 2: If response contains explicit ID references - id_matches = re.findall(r"\bid[:\s]*(\d+)", ai_response.lower()) - if id_matches: - fund_ids = [int(id_str) for id_str in id_matches] + result = db_session.execute(text(sql_query)) + fund_ids = [row[0] for row in result.fetchall()] + logger.info( + f"Found {len(fund_ids)} fund IDs: {fund_ids[:10]}{'...' if len(fund_ids) > 10 else ''}" + ) + return self._fetch_funds_by_ids(fund_ids, project_id) except Exception as e: - print(f"Error extracting IDs from response: {e}") - return [] - - return fund_ids + logger.error(f"SQL execution error: {e}") + logger.error(f"Failed SQL: {sql_query}") + # Return empty result + return PaginatedResponse( + items=[], total=0, page=1, page_size=10, total_pages=0 + ) + finally: + db_session.close() def _fetch_funds_by_ids( self, fund_ids: List[int], project_id: Optional[int] = None