"""Setup utility for gcld3."""
import os
import platform
import shutil
import subprocess

import setuptools
from setuptools.command import build_ext
  8. __version__ = '3.0.13'
  9. _NAME = 'gcld3'
  10. REQUIREMENTS = ['pybind11 >= 2.5.0', 'wheel >= 0.34.2']
  11. PROTO_FILES = [
  12. 'src/feature_extractor.proto',
  13. 'src/sentence.proto',
  14. 'src/task_spec.proto',
  15. ]
  16. SRCS = [
  17. 'src/base.cc',
  18. 'src/embedding_feature_extractor.cc',
  19. 'src/embedding_network.cc',
  20. 'src/feature_extractor.cc',
  21. 'src/feature_types.cc',
  22. 'src/fml_parser.cc',
  23. 'src/lang_id_nn_params.cc',
  24. 'src/language_identifier_features.cc',
  25. 'src/language_identifier_main.cc',
  26. 'src/nnet_language_identifier.cc',
  27. 'src/registry.cc',
  28. 'src/relevant_script_feature.cc',
  29. 'src/sentence_features.cc',
  30. 'src/task_context.cc',
  31. 'src/task_context_params.cc',
  32. 'src/unicodetext.cc',
  33. 'src/utils.cc',
  34. 'src/workspace.cc',
  35. 'src/script_span/fixunicodevalue.cc',
  36. 'src/script_span/generated_entities.cc',
  37. 'src/script_span/generated_ulscript.cc',
  38. 'src/script_span/getonescriptspan.cc',
  39. 'src/script_span/offsetmap.cc',
  40. 'src/script_span/text_processing.cc',
  41. 'src/script_span/utf8statetable.cc',
  42. # These CC files have to be generated by the proto buffer compiler 'protoc'
  43. 'src/cld_3/protos/feature_extractor.pb.cc',
  44. 'src/cld_3/protos/sentence.pb.cc',
  45. 'src/cld_3/protos/task_spec.pb.cc',
  46. # pybind11 bindings
  47. 'gcld3/pybind_ext.cc',
  48. ]
  49. class CompileProtos(build_ext.build_ext):
  50. """Compile protocol buffers via `protoc` compiler."""
  51. def run(self):
  52. if shutil.which('protoc') is None:
  53. raise RuntimeError('Please install the proto buffer compiler.')
  54. # The C++ code expect the protos to be compiled under the following
  55. # directory, therefore, create it if necessary.
  56. compiled_protos_dir = 'src/cld_3/protos/'
  57. os.makedirs(compiled_protos_dir, exist_ok=True)
  58. command = ['protoc', f'--cpp_out={compiled_protos_dir}', '--proto_path=src']
  59. command.extend(PROTO_FILES)
  60. subprocess.run(command, check=True, cwd='./')
  61. build_ext.build_ext.run(self)
  62. class PyBindIncludes(object):
  63. """Returns the include paths for pybind11 when needed.
  64. To delay the invocation of "pybind11.get_include()" until it is available
  65. in the environment. This lazy evaluation allows us to install it first, then
  66. import it later to determine the correct include paths.
  67. """
  68. def __str__(self):
  69. import pybind11 # pylint: disable=g-import-not-at-top
  70. return pybind11.get_include()
  71. MACOS = platform.system() == 'Darwin'
  72. ext_modules = [
  73. setuptools.Extension(
  74. 'gcld3.pybind_ext',
  75. sorted(SRCS),
  76. include_dirs=[
  77. PyBindIncludes(),
  78. ],
  79. libraries=['protobuf'],
  80. extra_compile_args=['-std=c++11', '-stdlib=libc++'] if MACOS else [],
  81. extra_link_args=['-stdlib=libc++'] if MACOS else [],
  82. language='c++'),
  83. ]
  84. DESCRIPTION = """CLD3 is a neural network model for language identification.
  85. This package contains the inference code and a trained model. See
  86. https://github.com/google/cld3 for more details.
  87. """
  88. setuptools.setup(
  89. author='Rami Al-Rfou',
  90. author_email='rmyeid@google.com',
  91. cmdclass={
  92. 'build_ext': CompileProtos,
  93. },
  94. ext_modules=ext_modules,
  95. packages=setuptools.find_packages(),
  96. description='CLD3 is a neural network model for language identification.',
  97. long_description=DESCRIPTION,
  98. name=_NAME,
  99. setup_requires=REQUIREMENTS,
  100. url='https://github.com/google/cld3',
  101. version=__version__,
  102. zip_safe=False,
  103. )