#!/usr/bin/env python3
# Copyright (C) 2018 The Android Open Source Project
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import re
import sys
from codecs import open

# Root of the checkout: the parent of the directory that contains this script
# (symlinks in the script's own path are resolved by realpath()).
PROJECT_ROOT = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
# Path of this script relative to PROJECT_ROOT, with '\' converted to '/' so the
# value substituted into REPLACEMENT_HEADER does not depend on the host OS.
# NOTE(review): unlike PROJECT_ROOT, __file__ is not passed through realpath()
# here, so running the script via a symlink could produce a '../'-style path.
SELF_PATH = os.path.relpath(__file__, PROJECT_ROOT).replace('\\', '/')

# Each (roots, output) pair below describes one merged proto produced by main():
# the root files plus everything they transitively import are concatenated into
# the output file. All paths are relative to PROJECT_ROOT.

# Roots of the merged config proto.
CONFIG_PROTO_ROOTS = [
    'protos/perfetto/common/data_source_descriptor.proto',
    'protos/perfetto/common/tracing_service_state.proto',
    'protos/perfetto/config/trace_config.proto'
]
MERGED_CONFIG_PROTO = 'protos/perfetto/config/perfetto_config.proto'

# The merged trace proto includes all of the config roots plus trace.proto.
TRACE_PROTO_ROOTS = CONFIG_PROTO_ROOTS + [
    'protos/perfetto/trace/trace.proto',
]
MERGED_TRACE_PROTO = 'protos/perfetto/trace/perfetto_trace.proto'

# Roots of the merged metrics proto.
METRICS_PROTOS_ROOTS = ['protos/perfetto/metrics/metrics.proto']
MERGED_METRICS_PROTO = 'protos/perfetto/metrics/perfetto_merged_metrics.proto'

# Header written at the top of every merged file. merge_protos_content() strips
# the leading newline and fills the single '%s' with SELF_PATH. The text is
# emitted verbatim, so editing it makes every previously generated output
# differ, and a --check-only run fails until the outputs are regenerated.
REPLACEMENT_HEADER = '''
// AUTOGENERATED - DO NOT EDIT
// ---------------------------
// This file has been generated by
// AOSP://external/perfetto/%s
// merging the perfetto config protos.
// This fused proto is intended to be copied in:
//  - Android tree, for statsd.
//  - Google internal repos.

syntax = "proto2";

package perfetto.protos;

option go_package = "github.com/google/perfetto/perfetto_proto";
'''


def get_transitive_imports(rel_path, visited=None):
  """Returns |rel_path| preceded by every proto it transitively imports.

  The list is in dependency order: each file comes after all the files it
  imports. Imports of a file are expanded in sorted order so that the result
  is deterministic.

  Args:
    rel_path: path of a .proto file, relative to PROJECT_ROOT.
    visited: set of paths that were already expanded. It is updated in place
      and shared across the recursion, so a file reachable through several
      import chains (or through an import cycle) is listed only once.
      Defaults to a fresh empty set.

  Returns:
    A list of paths relative to PROJECT_ROOT; empty if |rel_path| was already
    in |visited|.
  """
  if visited is None:
    visited = set()
  if rel_path in visited:
    return []
  visited.add(rel_path)
  with open(os.path.join(PROJECT_ROOT, rel_path), 'r', encoding='utf-8') as f:
    content = f.read()
  # Accept both `import "x";` and `import public "x";` (a public import is
  # still a dependency whose types must be merged). The path stops at the
  # first closing quote, and nothing is required after the ';' so that
  # trailing comments and CRLF line endings do not hide an import.
  imports = re.findall(r'^import +(?:public +)?"([^"]+)";', content,
                       flags=re.MULTILINE)
  res = []
  for child in sorted(imports):
    res += get_transitive_imports(child, visited)
  res.append(rel_path)
  return res


def merge_protos_content(proto_paths):
  """Concatenates the given protos into one self-contained proto source.

  REPLACEMENT_HEADER (with SELF_PATH substituted) comes first. Then, for each
  path not seen before in |proto_paths|:
    - the file's leading comment block through its `package` statement is
      dropped, as are all `import` lines;
    - the rest is appended between "// Begin of <path>" and "// End of <path>"
      markers.
  Finally, fields whose type is not defined by any message/enum in the merged
  text are replaced by a "// removed field with id N" comment.

  Args:
    proto_paths: .proto paths relative to PROJECT_ROOT, in output order.

  Returns:
    The merged proto source.

  Raises:
    Exception: if a file does not start with a comment followed, somewhere
      later, by a `package ...;` statement.
  """
  parts = [REPLACEMENT_HEADER.lstrip() % SELF_PATH]
  added_files = set()
  for proto in proto_paths:
    if proto in added_files:
      continue
    added_files.add(proto)

    path = os.path.join(PROJECT_ROOT, proto)
    with open(path, 'r', encoding='utf-8') as f:
      # The module-level `open` is codecs.open, which reads in binary mode and
      # does not translate line endings. Normalize CRLF so the '\n'-anchored
      # regexes below also work on, e.g., Windows checkouts.
      content = f.read().replace('\r\n', '\n')

    # Remove header (license comment, syntax and package statements).
    header = re.match(r'\/(\*|\/)(?s:.)*?package .*;\n', content)
    if header is None:
      raise Exception('Proto file ' + path + ' does not specify a package')
    content = content[header.end():]

    content = re.sub(r'^import.*?\n\n?', '', content, flags=re.MULTILINE)
    parts.append('\n// Begin of %s\n' % proto)
    parts.append(content)
    parts.append('\n// End of %s\n' % proto)

  return _remove_fields_of_unknown_types(''.join(parts))


def _remove_fields_of_unknown_types(merged_content):
  """Comments out fields whose type is not defined in |merged_content|.

  Limitation: defined types are tracked without their nesting, so a reference
  to a nested message (optional One.Two f = 1;) is resolved by its leafmost
  name (Two in this example).
  """
  definitions_re = r'^ *(?:message|enum) ([A-Z][A-Za-z0-9].*) {'
  types = {
      match.group(1)
      for match in re.finditer(definitions_re, merged_content, re.MULTILINE)
  }

  # Groups: (1) indentation, (2) leafmost type name, (3) field number.
  # Handles the optional/repeated/required labels (or none, e.g. in a oneof)
  # and trailing field options such as `[deprecated = true]`.
  uses_re = (r'^( +)(?:(?:repeated|optional|required) +)?'
             r'(?:[A-Z]\w+\.)*([A-Z]\w+)\s+[a-z]\w*\s*=\s*(\d+)'
             r'\s*(?:\[[^\]]*\])?\s*;')

  def replace_if_unknown(match):
    indentation, used_type, field_number = match.groups()
    if used_type in types:
      return match.group(0)
    return '{}// removed field with id {}'.format(indentation, field_number)

  # A single re.sub pass rewrites exactly the matched fields. Calling
  # str.replace per field would also rewrite identical text elsewhere and
  # rescan the whole merged file once per removed field.
  return re.sub(uses_re, replace_if_unknown, merged_content, flags=re.MULTILINE)


def merge_protos(root_paths, output_path, check_only=None):
  """Regenerates |output_path| from |root_paths| and their transitive imports.

  Args:
    root_paths: .proto paths, relative to PROJECT_ROOT, to merge together with
      everything they transitively import.
    output_path: path of the merged proto, relative to PROJECT_ROOT.
    check_only: if true, never write |output_path|; only report whether it is
      up to date. Defaults to whether '--check-only' is in sys.argv.

  Returns:
    True if |output_path| already had the expected content or was rewritten;
    False if it is stale and |check_only| is set.
  """
  if check_only is None:
    check_only = '--check-only' in sys.argv

  all_protos = []
  for root_path in root_paths:
    # A fresh |visited| per root; duplicates across roots are dropped later by
    # merge_protos_content().
    all_protos += get_transitive_imports(root_path, visited=set())
  merged_content = merge_protos_content(all_protos)

  out_path = os.path.join(PROJECT_ROOT, output_path)
  prev_content = None
  if os.path.exists(out_path):
    with open(out_path, 'r', encoding='utf-8') as fprev:
      prev_content = fprev.read()

  if prev_content == merged_content:
    return True

  if check_only:
    # Say which file is stale; otherwise the script just exits non-zero.
    print('{} is out of date'.format(output_path), file=sys.stderr)
    return False

  print('Updating {}'.format(output_path))
  with open(out_path, 'w', encoding='utf-8') as fout:
    fout.write(merged_content)
  return True


def main():
  """Regenerates (or, with --check-only, verifies) each merged proto.

  Every output is processed even if an earlier one reports a failure.

  Returns:
    The process exit code: 0 if every merge_protos() call returned True,
    1 otherwise.
  """
  jobs = (
      (CONFIG_PROTO_ROOTS, MERGED_CONFIG_PROTO),
      (TRACE_PROTO_ROOTS, MERGED_TRACE_PROTO),
      (METRICS_PROTOS_ROOTS, MERGED_METRICS_PROTO),
  )
  # Build a list rather than passing a generator to all(), which would stop
  # at the first False and skip the remaining outputs.
  outcomes = [merge_protos(roots, output) for roots, output in jobs]
  return 0 if all(outcomes) else 1


if __name__ == '__main__':
  # Use main()'s return value as the process exit status.
  raise SystemExit(main())
