1+ package org .springdoc .core .customisers ;
2+
3+ import java .lang .reflect .Field ;
4+ import java .lang .reflect .Modifier ;
5+ import java .lang .reflect .Type ;
6+ import java .util .ArrayList ;
7+ import java .util .Arrays ;
8+ import java .util .Collections ;
9+ import java .util .List ;
10+ import java .util .Map ;
11+ import java .util .Optional ;
12+ import java .util .Set ;
13+ import java .util .stream .Collectors ;
14+
15+ import com .querydsl .core .types .Path ;
16+ import io .swagger .v3 .core .converter .ModelConverters ;
17+ import io .swagger .v3 .core .converter .ResolvedSchema ;
18+ import io .swagger .v3 .core .util .PrimitiveType ;
19+ import io .swagger .v3 .oas .models .Operation ;
20+ import io .swagger .v3 .oas .models .media .Schema ;
21+ import io .swagger .v3 .oas .models .parameters .Parameter ;
22+ import org .apache .commons .lang3 .StringUtils ;
23+ import org .slf4j .Logger ;
24+ import org .slf4j .LoggerFactory ;
25+ import org .springdoc .core .customizers .OperationCustomizer ;
26+
27+ import org .springframework .core .LocalVariableTableParameterNameDiscoverer ;
28+ import org .springframework .core .MethodParameter ;
29+ import org .springframework .data .querydsl .binding .QuerydslBinderCustomizer ;
30+ import org .springframework .data .querydsl .binding .QuerydslBindings ;
31+ import org .springframework .data .querydsl .binding .QuerydslBindingsFactory ;
32+ import org .springframework .data .querydsl .binding .QuerydslPredicate ;
33+ import org .springframework .data .util .CastUtils ;
34+ import org .springframework .data .util .ClassTypeInformation ;
35+ import org .springframework .data .util .TypeInformation ;
36+ import org .springframework .web .method .HandlerMethod ;
37+
38+ /**
39+ * @author Gibah Joseph
40+ * Email: gibahjoe@gmail.com
41+ * Mar, 2020
42+ **/
43+ public class QuerydslPredicateOperationCustomizer implements OperationCustomizer {
44+ private static final Logger LOGGER = LoggerFactory .getLogger (QuerydslPredicateOperationCustomizer .class );
45+ private QuerydslBindingsFactory querydslBindingsFactory ;
46+ private LocalVariableTableParameterNameDiscoverer localVariableTableParameterNameDiscoverer ;
47+
48+ public QuerydslPredicateOperationCustomizer (QuerydslBindingsFactory querydslBindingsFactory , LocalVariableTableParameterNameDiscoverer localVariableTableParameterNameDiscoverer ) {
49+ this .querydslBindingsFactory = querydslBindingsFactory ;
50+ this .localVariableTableParameterNameDiscoverer = localVariableTableParameterNameDiscoverer ;
51+ }
52+
53+ @ Override
54+ public Operation customize (Operation operation , HandlerMethod handlerMethod ) {
55+ if (operation .getParameters () == null ) {
56+ return operation ;
57+ }
58+
59+ MethodParameter [] methodParameters = handlerMethod .getMethodParameters ();
60+ String [] methodParameterNames = this .localVariableTableParameterNameDiscoverer .getParameterNames (handlerMethod .getMethod ());
61+ String [] reflectionParametersNames = Arrays .stream (methodParameters ).map (MethodParameter ::getParameterName ).toArray (String []::new );
62+ if (methodParameterNames == null ) {
63+ methodParameterNames = reflectionParametersNames ;
64+ }
65+ int parametersLength = methodParameters .length ;
66+ List <Parameter > parametersToAddToOperation = new ArrayList <>();
67+ for (int i = 0 ; i < parametersLength ; i ++) {
68+ MethodParameter parameter = methodParameters [i ];
69+ QuerydslPredicate predicate = parameter .getParameterAnnotation (QuerydslPredicate .class );
70+
71+ if (predicate == null ) {
72+ continue ;
73+ }
74+
75+ List <io .swagger .v3 .oas .models .parameters .Parameter > operationParameters = operation .getParameters ();
76+ QuerydslBindings bindings = extractQdslBindings (predicate );
77+
78+ Set <String > fieldsToAdd = Arrays .stream (predicate .root ().getDeclaredFields ()).map (Field ::getName ).collect (Collectors .toSet ());
79+
80+ Map <String , Object > pathSpecMap = getPathSpec (bindings , "pathSpecs" );
81+ //remove blacklisted fields
82+ Set <String > blacklist = getFieldValues (bindings , "blackList" );
83+ fieldsToAdd .removeIf (blacklist ::contains );
84+
85+ Set <String > whiteList = getFieldValues (bindings , "whiteList" );
86+ Set <String > aliases = getFieldValues (bindings , "aliases" );
87+
88+ fieldsToAdd .addAll (aliases );
89+ fieldsToAdd .addAll (whiteList );
90+ for (String fieldName : fieldsToAdd ) {
91+ Type type = getFieldType (fieldName , pathSpecMap , predicate .root ());
92+ io .swagger .v3 .oas .models .parameters .Parameter newParameter = buildParam (type , fieldName );
93+
94+ parametersToAddToOperation .add (newParameter );
95+ }
96+ }
97+ operation .getParameters ().addAll (parametersToAddToOperation );
98+ return operation ;
99+ }
100+
101+ private QuerydslBindings extractQdslBindings (QuerydslPredicate predicate ) {
102+ ClassTypeInformation <?> classTypeInformation = ClassTypeInformation .from (predicate .root ());
103+ TypeInformation <?> domainType = classTypeInformation .getRequiredActualType ();
104+
105+ Optional <Class <? extends QuerydslBinderCustomizer <?>>> bindingsAnnotation = Optional .of (predicate )
106+ .map (QuerydslPredicate ::bindings )
107+ .map (CastUtils ::cast );
108+
109+ return bindingsAnnotation
110+ .map (it -> querydslBindingsFactory .createBindingsFor (domainType , it ))
111+ .orElseGet (() -> querydslBindingsFactory .createBindingsFor (domainType ));
112+ }
113+
114+ private Set <String > getFieldValues (QuerydslBindings instance , String fieldName ) {
115+ try {
116+ Field field = instance .getClass ().getDeclaredField (fieldName );
117+ if (Modifier .isPrivate (field .getModifiers ())) {
118+ field .setAccessible (true );
119+ }
120+ return (Set <String >) field .get (instance );
121+ } catch (NoSuchFieldException | IllegalAccessException e ) {
122+ LOGGER .warn ("NoSuchFieldException or IllegalAccessException occurred : {}" , e .getMessage ());
123+ }
124+ return Collections .emptySet ();
125+ }
126+
127+ private Map <String , Object > getPathSpec (QuerydslBindings instance , String fieldName ) {
128+ try {
129+ Field field = instance .getClass ().getDeclaredField (fieldName );
130+ if (Modifier .isPrivate (field .getModifiers ())) {
131+ field .setAccessible (true );
132+ }
133+ return (Map <String , Object >) field .get (instance );
134+ } catch (NoSuchFieldException | IllegalAccessException e ) {
135+ LOGGER .warn ("NoSuchFieldException or IllegalAccessException occurred : {}" , e .getMessage ());
136+ }
137+ return Collections .emptyMap ();
138+ }
139+
140+ private Optional <Path <?>> getPathFromPathSpec (Object instance ) {
141+ try {
142+ if (instance == null ) {
143+ return Optional .empty ();
144+ }
145+ Field field = instance .getClass ().getDeclaredField ("path" );
146+ if (Modifier .isPrivate (field .getModifiers ())) {
147+ field .setAccessible (true );
148+ }
149+ return (Optional <Path <?>>) field .get (instance );
150+ } catch (NoSuchFieldException | IllegalAccessException e ) {
151+ LOGGER .warn ("NoSuchFieldException or IllegalAccessException occurred : {}" , e .getMessage ());
152+ }
153+ return Optional .empty ();
154+ }
155+
156+ /***
157+ * Tries to figure out the Type of the field. It first checks the Qdsl pathSpecMap before checking the root class. Defaults to String.class
158+ * @param fieldName The name of the field used as reference to get the type
159+ * @param pathSpecMap The Qdsl path specifications as defined in the resolved bindings
160+ * @param root The root type where the paths are gotten
161+ * @return The type of the field. Returns
162+ */
163+ private Type getFieldType (String fieldName , Map <String , Object > pathSpecMap , Class <?> root ) {
164+ try {
165+ Object pathAndBinding = pathSpecMap .get (fieldName );
166+ Optional <Path <?>> path = getPathFromPathSpec (pathAndBinding );
167+
168+ Type genericType ;
169+ Field declaredField = null ;
170+ if (path .isPresent ()) {
171+ genericType = path .get ().getType ();
172+ } else {
173+ declaredField = root .getDeclaredField (fieldName );
174+ genericType = declaredField .getGenericType ();
175+ }
176+ if (genericType != null ) {
177+ return genericType ;
178+ }
179+ } catch (NoSuchFieldException e ) {
180+ LOGGER .warn ("Field {} not found on {} : {}" , fieldName , root .getName (), e .getMessage ());
181+ }
182+ return String .class ;
183+ }
184+
185+ /***
186+ * Constructs the parameter
187+ * @param type The type of the parameter
188+ * @param name The name of the parameter
189+ * @return The swagger parameter
190+ */
191+ private io .swagger .v3 .oas .models .parameters .Parameter buildParam (Type type , String name ) {
192+ io .swagger .v3 .oas .models .parameters .Parameter parameter = new io .swagger .v3 .oas .models .parameters .Parameter ();
193+
194+ if (StringUtils .isBlank (parameter .getName ())) {
195+ parameter .setName (name );
196+ }
197+
198+ if (StringUtils .isBlank (parameter .getIn ())) {
199+ parameter .setIn ("query" );
200+ }
201+
202+ if (parameter .getSchema () == null ) {
203+ Schema <?> schema = null ;
204+ PrimitiveType primitiveType = PrimitiveType .fromType (type );
205+ if (primitiveType != null ) {
206+ schema = primitiveType .createProperty ();
207+ } else {
208+ ResolvedSchema resolvedSchema = ModelConverters .getInstance ()
209+ .resolveAsResolvedSchema (
210+ new io .swagger .v3 .core .converter .AnnotatedType (type ).resolveAsRef (true ));
211+ // could not resolve the schema or this schema references other schema
212+ // we dont want this since there's no reference to the components in order to register a new schema if it doesnt already exist
213+ // defaulting to string
214+ if (resolvedSchema == null || !resolvedSchema .referencedSchemas .isEmpty ()) {
215+ schema = PrimitiveType .fromType (String .class ).createProperty ();
216+ } else {
217+ schema = resolvedSchema .schema ;
218+ }
219+ }
220+ parameter .setSchema (schema );
221+ }
222+ return parameter ;
223+ }
224+ }
0 commit comments