@@ -363,6 +363,40 @@ def test_func_deterministic_keyword_only(self):
363363 with self .assertRaises (TypeError ):
364364 self .con .create_function ("deterministic" , 0 , int , True )
365365
366+ @unittest .skipIf (sqlite .sqlite_version_info < (3 , 31 , 0 ),
367+ "Requires SQLite 3.31.0 or higher" )
368+ def test_func_non_innocuous_in_trusted_env (self ):
369+ mock = Mock (return_value = None )
370+ self .con .create_function ("noninnocuous" , 0 , mock , innocuous = False )
371+ self .con .execute ("pragma trusted_schema = 0" )
372+ self .con .execute ("create view notallowed as select noninnocuous() = noninnocuous()" )
373+ with self .assertRaises (sqlite .OperationalError ) as cm :
374+ self .con .execute ("select * from notallowed" )
375+ self .assertEqual (str (cm .exception ), 'unsafe use of noninnocuous()' )
376+
377+ @unittest .skipIf (sqlite .sqlite_version_info < (3 , 31 , 0 ),
378+ "Requires SQLite 3.31.0 or higher" )
379+ def test_func_innocuous_in_trusted_env (self ):
380+ mock = Mock (return_value = None )
381+ self .con .create_function ("innocuous" , 0 , mock , innocuous = True )
382+ self .con .execute ("pragma trusted_schema = 0" )
383+ self .con .execute ("create view allowed as select innocuous() = innocuous()" )
384+ self .con .execute ("select * from allowed" )
385+ self .assertEqual (mock .call_count , 2 )
386+
387+ @unittest .skipIf (sqlite .sqlite_version_info < (3 , 31 , 0 ),
388+ "Requires SQLite 3.31.0 or higher" )
389+ def test_func_direct_only (self ):
390+ mock = Mock (return_value = None )
391+ self .con .create_function ("directonly" , 0 , mock , directonly = True )
392+ self .con .execute ("pragma trusted_schema = 1" )
393+ self .con .execute ("select directonly() = directonly()" )
394+ self .assertEqual (mock .call_count , 2 )
395+ self .con .execute ("create view notallowed as select directonly() = directonly()" )
396+ with self .assertRaises (sqlite .OperationalError ) as cm :
397+ self .con .execute ("select * from notallowed" )
398+ self .assertEqual (str (cm .exception ), 'unsafe use of directonly()' )
399+
366400 def test_function_destructor_via_gc (self ):
367401 # See bpo-44304: The destructor of the user function can
368402 # crash if is called without the GIL from the gc functions
@@ -479,6 +513,9 @@ def setUp(self):
479513 from test order by x
480514 """
481515 self .con .create_window_function ("sumint" , 1 , WindowSumInt )
516+ if sqlite .sqlite_version_info >= (3 , 31 , 0 ):
517+ self .con .create_window_function ("sumintInnocuous" , 1 , WindowSumInt , innocuous = True )
518+ self .con .create_window_function ("sumintDirectOnly" , 1 , WindowSumInt , directonly = True )
482519
483520 def tearDown (self ):
484521 self .cur .close ()
@@ -488,6 +525,34 @@ def test_win_sum_int(self):
488525 self .cur .execute (self .query % "sumint" )
489526 self .assertEqual (self .cur .fetchall (), self .expected )
490527
528+ @unittest .skipIf (sqlite .sqlite_version_info < (3 , 31 , 0 ),
529+ "Requires SQLite 3.31.0 or newer" )
530+ def test_win_non_innocuous (self ):
531+ self .cur .execute ("pragma trusted_schema = 0" )
532+ self .cur .execute ("create view notallowed as " + self .query % "sumint" )
533+ with self .assertRaises (sqlite .OperationalError ) as cm :
534+ self .cur .execute ("select * from notallowed" )
535+ self .assertEqual (str (cm .exception ), 'unsafe use of sumint()' )
536+
537+ @unittest .skipIf (sqlite .sqlite_version_info < (3 , 31 , 0 ),
538+ "Requires SQLite 3.31.0 or newer" )
539+ def test_win_innocuous (self ):
540+ self .cur .execute ("pragma trusted_schema = 0" )
541+ self .cur .execute ("create view allowed as " + self .query % "sumintInnocuous" )
542+ self .cur .execute ("select * from allowed" )
543+ self .assertEqual (self .cur .fetchall (), self .expected )
544+
545+ @unittest .skipIf (sqlite .sqlite_version_info < (3 , 31 , 0 ),
546+ "Requires SQLite 3.31.0 or newer" )
547+ def test_win_directonly (self ):
548+ self .cur .execute ("pragma trusted_schema = 1" )
549+ self .cur .execute ("create view notallowed as " + self .query % "sumintDirectOnly" )
550+ with self .assertRaises (sqlite .OperationalError ) as cm :
551+ self .cur .execute ("select * from notallowed" )
552+ self .assertEqual (str (cm .exception ), 'unsafe use of sumintDirectOnly()' )
553+ self .cur .execute (self .query % "sumintDirectOnly" )
554+ self .assertEqual (self .cur .fetchall (), self .expected )
555+
491556 def test_win_error_on_create (self ):
492557 with self .assertRaisesRegex (sqlite .ProgrammingError , "not -100" ):
493558 self .con .create_window_function ("shouldfail" , - 100 , WindowSumInt )
@@ -614,6 +679,9 @@ def setUp(self):
614679 self .con .create_aggregate ("checkTypes" , - 1 , AggrCheckTypes )
615680 self .con .create_aggregate ("mysum" , 1 , AggrSum )
616681 self .con .create_aggregate ("aggtxt" , 1 , AggrText )
682+ if sqlite .sqlite_version_info >= (3 , 31 , 0 ):
683+ self .con .create_aggregate ("mysumInnocuous" , 1 , AggrSum , innocuous = True )
684+ self .con .create_aggregate ("mysumDirectOnly" , 1 , AggrSum , directonly = True )
617685
618686 def tearDown (self ):
619687 self .con .close ()
@@ -705,6 +773,45 @@ def test_aggr_check_aggr_sum(self):
705773 val = cur .fetchone ()[0 ]
706774 self .assertEqual (val , 60 )
707775
776+ @unittest .skipIf (sqlite .sqlite_version_info < (3 , 31 , 0 ),
777+ "Requires SQLite 3.31.0 or newer" )
778+ def test_aggr_non_innocuous (self ):
779+ cur = self .con .cursor ()
780+ cur .execute ("pragma trusted_schema = 0" )
781+ cur .execute ("delete from test" )
782+ cur .execute ("insert into test(i) values (?)" , (10 ,))
783+ cur .execute ("create view notallowed as select mysum(i) from test" )
784+ with self .assertRaises (sqlite .OperationalError ) as cm :
785+ cur .execute ("select * from notallowed" )
786+ self .assertEqual (str (cm .exception ), 'unsafe use of mysum()' )
787+
788+ @unittest .skipIf (sqlite .sqlite_version_info < (3 , 31 , 0 ),
789+ "Requires SQLite 3.31.0 or newer" )
790+ def test_aggr_innocuous (self ):
791+ cur = self .con .cursor ()
792+ cur .execute ("pragma trusted_schema = 0" )
793+ cur .execute ("delete from test" )
794+ cur .executemany ("insert into test(i) values (?)" , [(10 ,), (20 ,), (30 ,)])
795+ cur .execute ("create view allowed as select mysumInnocuous(i) from test" )
796+ cur .execute ("select * from allowed" )
797+ val = cur .fetchone ()[0 ]
798+ self .assertEqual (val , 60 )
799+
800+ @unittest .skipIf (sqlite .sqlite_version_info < (3 , 31 , 0 ),
801+ "Requires SQLite 3.31.0 or newer" )
802+ def test_aggr_directonly (self ):
803+ cur = self .con .cursor ()
804+ cur .execute ("pragma trusted_schema = 1" )
805+ cur .execute ("delete from test" )
806+ cur .executemany ("insert into test(i) values (?)" , [(10 ,), (20 ,), (30 ,)])
807+ cur .execute ("create view notallowed as select mysumDirectOnly(i) from test" )
808+ with self .assertRaises (sqlite .OperationalError ) as cm :
809+ cur .execute ("select * from notallowed" )
810+ self .assertEqual (str (cm .exception ), 'unsafe use of mysumDirectOnly()' )
811+ cur .execute ("select mysumDirectOnly(i) from test" )
812+ val = cur .fetchone ()[0 ]
813+ self .assertEqual (val , 60 )
814+
708815 def test_aggr_no_match (self ):
709816 cur = self .con .execute ("select mysum(i) from (select 1 as i) where i == 0" )
710817 val = cur .fetchone ()[0 ]
0 commit comments