3333import com .oracle .graal .python .builtins .objects .function .PFunction ;
3434import com .oracle .graal .python .builtins .objects .object .PythonBuiltinObject ;
3535import com .oracle .graal .python .builtins .objects .object .PythonObject ;
36+ import com .oracle .graal .python .compiler .CodeUnit ;
3637import com .oracle .graal .python .nodes .bytecode .BytecodeFrameInfo ;
3738import com .oracle .graal .python .nodes .bytecode .FrameInfo ;
3839import com .oracle .graal .python .nodes .bytecode .GeneratorYieldResult ;
@@ -66,8 +67,6 @@ public class PGenerator extends PythonBuiltinObject {
6667 private boolean finished ;
6768 // running means it is currently on the stack, not just started
6869 private boolean running ;
69- private final boolean isCoroutine ;
70- private final boolean isAsyncGen ;
7170
7271 private PCode code ;
7372
@@ -117,21 +116,20 @@ public Object handleResult(PythonLanguage language, GeneratorYieldResult result)
117116 public static class BytecodeDSLState {
118117 private final PBytecodeDSLRootNode rootNode ;
119118 private final Object [] arguments ;
120- private BytecodeNode bytecodeNode ;
119+ private BytecodeLocation lastLocation ;
121120 private ContinuationRootNode continuationRootNode ;
122121 private boolean isStarted ;
123122
124123 public BytecodeDSLState (PBytecodeDSLRootNode rootNode , Object [] arguments , ContinuationRootNode continuationRootNode ) {
125124 this .rootNode = rootNode ;
126125 this .arguments = arguments ;
127126 this .continuationRootNode = continuationRootNode ;
128- this .bytecodeNode = rootNode .getBytecodeNode ();
129127 }
130128
131129 public Object handleResult (PGenerator generator , ContinuationResult result ) {
132130 assert result .getContinuationRootNode () == null || result .getContinuationRootNode ().getFrameDescriptor () == generator .frame .getFrameDescriptor ();
133131 isStarted = true ;
134- bytecodeNode = continuationRootNode .getLocation (). getBytecodeNode ();
132+ lastLocation = continuationRootNode .getLocation ();
135133 continuationRootNode = result .getContinuationRootNode ();
136134 return result .getResult ();
137135 }
@@ -148,40 +146,26 @@ private BytecodeDSLState getBytecodeDSLState() {
148146 return (BytecodeDSLState ) state ;
149147 }
150148
151- // An explicit isIterableCoroutine argument is needed for iterable coroutines (generally created
152- // via types.coroutine)
153149 public static PGenerator create (PythonLanguage lang , PFunction function , PBytecodeRootNode rootNode , RootCallTarget [] callTargets , Object [] arguments ,
154- PythonBuiltinClassType cls , boolean isIterableCoroutine ) {
150+ PythonBuiltinClassType cls ) {
155151 // note: also done in PAsyncGen.create
156152 MaterializedFrame generatorFrame = rootNode .createGeneratorFrame (arguments );
157- return new PGenerator (lang , function , generatorFrame , cls , isIterableCoroutine , new BytecodeState (rootNode , callTargets ));
158- }
159-
160- public static PGenerator create (PythonLanguage lang , PFunction function , PBytecodeDSLRootNode rootNode , Object [] arguments ,
161- PythonBuiltinClassType cls , boolean isIterableCoroutine , ContinuationRootNode continuationRootNode , MaterializedFrame continuationFrame ) {
162- return new PGenerator (lang , function , continuationFrame , cls , isIterableCoroutine , new BytecodeDSLState (rootNode , arguments , continuationRootNode ));
163- }
164-
165- public static PGenerator create (PythonLanguage lang , PFunction function , PBytecodeRootNode rootNode , RootCallTarget [] callTargets , Object [] arguments ,
166- PythonBuiltinClassType cls ) {
167- return create (lang , function , rootNode , callTargets , arguments , cls , false );
153+ return new PGenerator (lang , function , generatorFrame , cls , new BytecodeState (rootNode , callTargets ));
168154 }
169155
170156 public static PGenerator create (PythonLanguage lang , PFunction function , PBytecodeDSLRootNode rootNode , Object [] arguments ,
171157 PythonBuiltinClassType cls , ContinuationRootNode continuationRootNode , MaterializedFrame continuationFrame ) {
172- return create (lang , function , rootNode , arguments , cls , false , continuationRootNode , continuationFrame );
158+ return new PGenerator (lang , function , continuationFrame , cls , new BytecodeDSLState ( rootNode , arguments , continuationRootNode ) );
173159 }
174160
175- protected PGenerator (PythonLanguage lang , PFunction function , MaterializedFrame frame , PythonBuiltinClassType cls , boolean isIterableCoroutine , Object state ) {
161+ protected PGenerator (PythonLanguage lang , PFunction function , MaterializedFrame frame , PythonBuiltinClassType cls , Object state ) {
176162 super (cls , cls .getInstanceShape (lang ));
177163 this .name = function .getName ();
178164 this .qualname = function .getQualname ();
179165 this .globals = function .getGlobals ();
180166 this .generatorFunction = function ;
181167 this .frame = frame ;
182168 this .finished = false ;
183- this .isCoroutine = isIterableCoroutine || cls == PythonBuiltinClassType .PCoroutine ;
184- this .isAsyncGen = cls == PythonBuiltinClassType .PAsyncGenerator ;
185169 if (PythonOptions .ENABLE_BYTECODE_DSL_INTERPRETER ) {
186170 BytecodeDSLState bytecodeDSLState = (BytecodeDSLState ) state ;
187171 this .state = state ;
@@ -311,7 +295,12 @@ public RootCallTarget getCurrentCallTarget() {
311295 */
312296 public BytecodeNode getBytecodeNode () {
313297 assert PythonOptions .ENABLE_BYTECODE_DSL_INTERPRETER ;
314- return getBytecodeDSLState ().bytecodeNode ;
298+ BytecodeDSLState state = getBytecodeDSLState ();
299+ if (state .lastLocation != null ) {
300+ return state .lastLocation .getBytecodeNode ();
301+ } else {
302+ return state .rootNode .getBytecodeNode ();
303+ }
315304 }
316305
317306 public BytecodeLocation getCurrentLocation () {
@@ -398,12 +387,21 @@ public final void setQualname(TruffleString qualname) {
398387 this .qualname = qualname ;
399388 }
400389
390+ private CodeUnit getCodeUnit () {
391+ if (PythonOptions .ENABLE_BYTECODE_DSL_INTERPRETER ) {
392+ return getBytecodeDSLState ().rootNode .getCodeUnit ();
393+ } else {
394+ return getBytecodeState ().rootNode .getCodeUnit ();
395+ }
396+ }
397+
401398 public final boolean isCoroutine () {
402- return isCoroutine ;
399+ CodeUnit codeUnit = getCodeUnit ();
400+ return codeUnit .isCoroutine () || codeUnit .isIterableCoroutine ();
403401 }
404402
405403 public final boolean isAsyncGen () {
406- return isAsyncGen ;
404+ return getCodeUnit (). isAsyncGenerator () ;
407405 }
408406
409407 public int getBci () {
0 commit comments