diff --git a/setup.py b/setup.py
index 35e9952edb86f1990965ebd9329ca8f3826ab99c..9a857c11fdbf14e4c11f579be152b6486dafe239 100644
--- a/setup.py
+++ b/setup.py
@@ -17,11 +17,12 @@ with (HERE / 'README.md').open('rt') as fh:
 REQUIREMENTS: dict = {
     'core': [
         'dataclasses;python_version<"3.7"',
-        'JPype1>=1.2.1,<2.*',
+        'JPype1>=1.2.1,<2.0.dev0',
     ],
     'test': [
         'pytest',
         'mypy>=0.931,<0.971',
+        "typing_extensions;python_version<'3.8'",  # Required for java-stubs
     ],
 }
 
diff --git a/stubgenj/_stubgenj.py b/stubgenj/_stubgenj.py
index 504ab80e00edf1afe3bdc37480185f5d43ca87c3..20960b56a913787c4049caa8676f15bae28b5ece 100644
--- a/stubgenj/_stubgenj.py
+++ b/stubgenj/_stubgenj.py
@@ -31,6 +31,7 @@ import dataclasses
 import functools
 import pathlib
 import re
+import textwrap
 from typing import Dict, List, Optional, Any, Set, Type, Union, Generator
 
 import jpype
@@ -197,12 +198,24 @@ def generateJPypeJPackageOverloadStubs(outputPath: pathlib.Path, topLevelPackage
 
     imports = []
     overloads = []
+
+    if topLevelPackages:
+        imports.append(textwrap.dedent(
+            """
+            import sys
+            if sys.version_info >= (3, 8):
+                from typing import Literal
+            else:
+                from typing_extensions import Literal
+            """,
+        ))
+
     for name in topLevelPackages:
         imports.append(f"import {name}")
         overloads.extend([
             '',
             '@typing.overload',
-            f'def JPackage(__package_name: typing.Literal[\'{name}\']) -> {name}.__module_protocol__: ...\n',
+            f'def JPackage(__package_name: Literal[\'{name}\']) -> {name}.__module_protocol__: ...\n',
         ])
 
     with jpypeStubsPath.open('wt') as fh:
@@ -348,9 +361,18 @@ def generateModuleProtocol(
     """ Mutate the given import and class output to include a __module_protocol__ typing.Protocol """
 
     importOutput.append('import typing')
+    importOutput.append(textwrap.dedent(
+        """
+        import sys
+        if sys.version_info >= (3, 8):
+            from typing import Protocol
+        else:
+            from typing_extensions import Protocol
+        """,
+    ))
 
     protocolOutput = [
-        'class __module_protocol__(typing.Protocol):',
+        'class __module_protocol__(Protocol):',
         f'    # A module protocol which reflects the result of ``jp.JPackage("{pkgName}")``.',
         '',
     ]