diff --git a/stubgenj/__main__.py b/stubgenj/__main__.py
index 0806ec0a7863e0eafa2f0cf991ac8fcd13c6f5c8..4f89642471b73f6e711da53ca4f3ed3e52dacaac 100644
--- a/stubgenj/__main__.py
+++ b/stubgenj/__main__.py
@@ -18,25 +18,29 @@ if __name__ == '__main__':
                         help='package prefixes to generate stubs for (e.g. org.myproject)')
     parser.add_argument('--jvmpath', type=str,
                         help='path to the JVM ("libjvm.so", "jvm.dll", ...) (default: use system default JVM)')
-    parser.add_argument('--classpath', type=str,
+    parser.add_argument('--classpath', type=str, default='.',
                         help='java class path to use, separated by ":". '
                              'glob-like expressions (e.g. dir/*.jar) are supported (default: .)')
-    parser.add_argument('--output-dir', type=str,
+    parser.add_argument('--output-dir', type=str, default='.',
                         help='path to write stubs to (default: .)')
-    parser.add_argument('--convert-strings', dest='convert_strings', action='store_true',
+    parser.add_argument('--convert-strings', dest='convert_strings', action='store_true', default=False,
                         help='convert java.lang.String to python str in return types. '
                              'consult the JPype documentation on the convertStrings flag for details')
-    parser.add_argument('--no-stubs-suffix', dest='stubs_suffix', action='store_true',
+    parser.add_argument('--no-stubs-suffix', dest='with_stubs_suffix', action='store_false', default=True,
                         help='do not use PEP-561 "-stubs" suffix for top-level packages')
-
-    parser.set_defaults(stubs_suffix=True, classpath='.', output_dir='.', convert_strings=False)
+    parser.add_argument('--no-jpackage-stubs', dest='with_jpackage_stubs', action='store_false', default=True,
+                        help='do not create a partial jpype-stubs package for jp.JPackage("<tld>") type interfaces')
 
     args = parser.parse_args()
     classpath = [c for c_in in args.classpath.split(':') for c in glob(c_in)]
     log.info('Starting JPype JVM with classpath ' + str(classpath))
     jpype.startJVM(jvmpath=args.jvmpath, classpath=classpath, convertStrings=args.convert_strings)  # noqa: exists
     prefixPackages = [importlib.import_module(prefix) for prefix in args.prefixes]
-
-    generateJavaStubs(prefixPackages, useStubsSuffix=args.stubs_suffix, outputDir=args.output_dir)
+    generateJavaStubs(
+        prefixPackages,
+        useStubsSuffix=args.with_stubs_suffix,
+        outputDir=args.output_dir,
+        jpypeJPackageStubs=args.with_jpackage_stubs,
+    )
     log.info('Generation done.')
     jpype.java.lang.Runtime.getRuntime().halt(0)
diff --git a/stubgenj/_stubgenj.py b/stubgenj/_stubgenj.py
index cda445f36133a8e35e068082f3c8588b80ebbfeb..a842bb23cbe90aae5a512ea92d0a4a01310670f8 100644
--- a/stubgenj/_stubgenj.py
+++ b/stubgenj/_stubgenj.py
@@ -26,6 +26,7 @@ Authors:
     P. Elson        <philip.elson@cern.ch>
 """
 
+import collections
 import dataclasses
 import functools
 import pathlib
@@ -82,8 +83,11 @@ def packageAndSubPackages(package: jpype.JPackage) -> Generator[jpype.JPackage,
             log.warning(f'skipping {package.__name__}.{name}: {e}')
 
 
-def generateJavaStubs(parentPackages: List[jpype.JPackage], useStubsSuffix: bool = True,
-                      outputDir: Union[str, pathlib.Path] = '.') -> None:
+def generateJavaStubs(parentPackages: List[jpype.JPackage],
+                      useStubsSuffix: bool = True,
+                      outputDir: Union[str, pathlib.Path] = '.',
+                      jpypeJPackageStubs: bool = True,
+                      ) -> None:
     """
     Main entry point. Recursively generate stubs for the provided packages and all sub-packages.
     This method assumes that a JPype JVM was started with a proper classpath and the JPype import system is enabled.
@@ -97,6 +101,17 @@ def generateJavaStubs(parentPackages: List[jpype.JPackage], useStubsSuffix: bool
 
     log.info(f'Collected {len(packages)} packages ...')
 
+    # Map package names to a list of direct subpackages
+    # (e.g {'foo.bar': ['wibble', 'wobble']}).
+    subpackages = collections.defaultdict(list)
+    for pkg in packages:
+        # If this package is a subpackage (i.e. it has a "." in the name) then
+        # get its parent's name, and add the package to the parent's list of
+        # subpackages.
+        if '.' in pkg.__name__:
+            parent, name = pkg.__name__.rsplit('.', 1)
+            subpackages[parent].append(name)
+
     outputPath = pathlib.Path(outputDir)
     for pkg in packages:
         pathParts = pkg.__name__.split('.')
@@ -109,7 +124,43 @@ def generateJavaStubs(parentPackages: List[jpype.JPackage], useStubsSuffix: bool
             initFile = submodulePath / '__init__.pyi'
             initFile.touch()
 
-        generateStubsForJavaPackage(pkg, submodulePath / '__init__.pyi')
+        generateStubsForJavaPackage(pkg, submodulePath / '__init__.pyi', subpackages[pkg.__name__])
+
+    if jpypeJPackageStubs:
+        tld_packages = {name.split('.')[0] for name in subpackages}
+        generateJPypeJPackageOverloadStubs(outputPath / 'jpype-stubs', sorted(tld_packages))
+
+
+def generateJPypeJPackageOverloadStubs(outputPath: pathlib.Path, topLevelPackages: List[str]):
+    """ Generate context for a jpype-stubs directory containing JPackage overloads for the given TLDs. """
+    outputPath.mkdir(parents=True, exist_ok=True)
+
+    log.info(f'Generating jpype-stubs for tld JPackages: {", ".join(topLevelPackages)}')
+
+    # Following the guidance at https://www.python.org/dev/peps/pep-0561/#partial-stub-packages
+    # we ensure that other type stubs for JPype are honoured (unless they are also defined
+    # in a different "jpype-stubs" directory in site-packages).
+    (outputPath / 'py.typed').write_text('partial\n')
+    jpypeStubsPath = outputPath / '__init__.pyi'
+
+    imports = []
+    overloads = []
+    for name in topLevelPackages:
+        imports.append(f"import {name}")
+        overloads.extend([
+            '',
+            '@typing.overload',
+            f'def JPackage(__package_name: typing.Literal[\'{name}\']) -> {name}.__module_protocol__: ...\n',
+        ])
+
+    with jpypeStubsPath.open('wt') as fh:
+        fh.writelines([
+            'import types\n',
+            'import typing\n\n',
+            '\n'.join(imports) + '\n\n',
+            '\n'.join(overloads) + '\n\n',
+            'def JPackage(__package_name) -> types.ModuleType: ...\n\n',
+        ])
 
 
 def filterClassNamesInPackage(packageName: str, types: Set[str]) -> Set[str]:
@@ -148,10 +199,11 @@ def provideCustomizerStubs(customizersUsed: Set[Type], importOutput: List[str],
         importOutput.append(f'from {c.__module__} import {c.__qualname__}')
 
 
-def generateStubsForJavaPackage(package: jpype.JPackage, outputFile: str) -> None:
+def generateStubsForJavaPackage(package: jpype.JPackage, outputFile: str, subpackages: List[str]) -> None:
     """ Generate stubs for a single Java package, represented as a python package with a single __init__ module. """
-    javaClasses = list(packageClasses(package))
-    log.info(f'Generating stubs for {package.__name__} ({len(javaClasses)} classes)')
+    pkgName = package.__name__
+    javaClasses = sorted(packageClasses(package), key=lambda pkg: pkg.__name__)
+    log.info(f'Generating stubs for {pkgName} ({len(javaClasses)} classes, {len(subpackages)} subpackages)')
 
     importOutput = []  # type: List[str]
     classOutput = []  # type: List[str]
@@ -172,7 +224,7 @@ def generateStubsForJavaPackage(package: jpype.JPackage, outputFile: str) -> Non
         #  - first, we attempt to get them by explicitly reading the attribute from the JPackage object. This may work
         #    for certain protected or module internal (Java 11) classes.
         #  - failing that, we generate an empty stub.
-        missingPrivateClasses = filterClassNamesInPackage(package.__name__, classesUsed) - classesDone
+        missingPrivateClasses = filterClassNamesInPackage(pkgName, classesUsed) - classesDone
         for missingPrivateClass in sorted(missingPrivateClasses):
             cls = getattr(package, missingPrivateClass, None)
 
@@ -202,18 +254,21 @@ def generateStubsForJavaPackage(package: jpype.JPackage, outputFile: str) -> Non
                 classOutput.append('')
                 generateEmptyClassStub(missingPrivateClass, classesDone=classesDone, output=classOutput)
 
-    if any(('typing.' in line) for line in classOutput):
-        importOutput.append('import typing')
+    generateModuleProtocol(
+        pkgName,
+        sorted([className for className in classesDone if '$' not in className]),
+        subpackages, importOutput, classOutput,
+    )
 
     if customizersUsed:
         provideCustomizerStubs(customizersUsed, importOutput, outputFile)
 
-    output = []
+    output = ['import typing\n']
 
     for line in sorted(set(importOutput)):
         output.append(line)
 
-    output.append('')
+    output.extend([''] * 2)
     for line in classOutput:
         output.append(line)
     with open(outputFile, 'w') as file:
@@ -221,6 +276,35 @@ def generateStubsForJavaPackage(package: jpype.JPackage, outputFile: str) -> Non
             file.write(f'{line}\n')
 
 
+def generateModuleProtocol(
+        pkgName: str,
+        classesInModule: List[str],
+        subpackages: List[str],
+        importOutput: List[str],
+        classOutput: List[str]
+) -> None:
+    """ Mutate the given import and class output to include a __module_protocol__ typing.Protocol """
+
+    protocolOutput = [
+        'class __module_protocol__(typing.Protocol):',
+        f'    # A module protocol which reflects the result of ``jp.JPackage("{pkgName}")``.',
+        '',
+    ]
+
+    for className in classesInModule:
+        protocolOutput.append(f'    {className}: typing.Type[{className}]')
+
+    for subpackage_name in subpackages:
+        importOutput.append(f'import {pkgName}.{subpackage_name}')
+        protocolOutput.append(f'    {subpackage_name}: {pkgName}.{subpackage_name}.__module_protocol__')
+    if not classesInModule and not subpackages:
+        protocolOutput.append('    pass')
+
+    if classOutput:
+        classOutput.extend([''] * 2)
+    classOutput.extend(protocolOutput)
+
+
 def isJavaClass(obj: type) -> bool:
     """ Check if a type is a 'real' Java class. This excludes synthetic/anonymous Java classes.