0xDEADBEEF

RSS odkazy

diff-compress.scala

18. 9. 2016 #kód
// One delimiter-terminated piece of a line. `hash` is stored next to the text;
// compressLines fills it with `str.hashCode`. Case-class equality compares both
// fields (presumably so that most mismatches are decided by the Int alone -- TODO confirm).
case class Segment(hash: Int, str: String)

// One record of the compressed format, as produced by parseLines.
sealed trait DiffGroup
// A line stored verbatim (serialized with a "* " prefix).
case class NewLine(str: String) extends DiffGroup
// A list of edits that turn the previous line into the current one
// (serialized as "+ ..." / "- ..." lines closed by "---").
case class Diffs(diffs: Seq[Diff]) extends DiffGroup

// A single edit; `pos` is a character offset into the PREVIOUS line's text.
sealed trait Diff { def pos: Int }
// Insert `str` at offset `pos` of the previous line.
case class Addition(pos: Int, str: String) extends Diff
// Remove `len` characters of the previous line starting at offset `pos`.
case class Deletion(pos: Int, len: Int) extends Diff



// Longest common subsequence of `a` and `b`, computed with the classic
// dynamic-programming table and returned in original order.
def lcs[T](a: Seq[T], b: Seq[T]): Seq[T] = {
  val rows = a.length
  val cols = b.length

  // table(r)(c) = LCS length of the first r elements of a and the first c of b;
  // row 0 and column 0 stay at their default value 0.
  val table: Array[Array[Int]] = Array.ofDim[Int](rows + 1, cols + 1)

  var r = 1
  while (r <= rows) {
    val above = table(r - 1)
    val current = table(r)
    val left = a(r - 1)

    var c = 1
    while (c <= cols) {
      current(c) =
        if (left == b(c - 1)) above(c - 1) + 1
        else math.max(current(c - 1), above(c))
      c += 1
    }
    r += 1
  }

  // Walk back from the bottom-right corner, preferring "up", then "left",
  // and taking the diagonal (a real match) only when neither neighbour
  // carries the same length.
  @scala.annotation.tailrec
  def backtrack(x: Int, y: Int, acc: List[T]): List[T] =
    if (x == 0 || y == 0) {
      acc
    } else if (table(x)(y) == table(x - 1)(y)) {
      backtrack(x - 1, y, acc)
    } else if (table(x)(y) == table(x)(y - 1)) {
      backtrack(x, y - 1, acc)
    } else {
      assert(a(x-1) == b(y-1))
      backtrack(x - 1, y - 1, a(x - 1) :: acc)
    }

  backtrack(rows, cols, Nil)
}


/** Splits `str` into pieces that each END with `delim` (except possibly the
  * last one), so that concatenating the pieces always yields `str` again.
  *
  * The delimiter is matched literally. The previous implementation used
  * `String.split`, which interprets the delimiter as a regular expression and
  * drops trailing empty strings; that lost characters for inputs such as
  * "a,," and failed for delimiters like "." or "|". Callers add up the piece
  * lengths to compute offsets, so the pieces must cover the whole input.
  *
  * @param str   text to split
  * @param delim literal delimiter; when empty, every character becomes a piece
  * @return the pieces in order; a single empty piece for an empty `str`
  */
def splitWithDelimiter(str: String, delim: String): Array[String] =
  if (str.isEmpty) {
    Array(str)
  } else if (delim.isEmpty) {
    // Keeps the historical behaviour of splitting on the empty pattern.
    str.toCharArray.map(_.toString)
  } else {
    val pieces = Array.newBuilder[String]
    var start = 0
    var hit = str.indexOf(delim, start)
    while (hit >= 0) {
      val end = hit + delim.length
      pieces += str.substring(start, end)
      start = end
      hit = str.indexOf(delim, start)
    }
    // Whatever follows the last delimiter is an unterminated final piece.
    if (start < str.length) {
      pieces += str.substring(start)
    }
    pieces.result()
  }


/** Computes the edits that turn the text of `a` into the text of `b`.
  *
  * The longest common subsequence of the two segment lists serves as a chain
  * of anchors. Segments of `a` between two anchors become one Deletion,
  * segments of `b` between the same anchors become one Addition. Both carry
  * the character offset (within the concatenated text of `a`) at which the gap
  * starts. For each gap the Deletion is emitted before the Addition.
  */
def makeDiff(a: Seq[Segment], b: Seq[Segment]): Seq[Diff] = {

  val edits = new collection.mutable.ArrayBuilder.ofRef[Diff]()

  var aIdx = 0     // next unconsumed segment of a
  var bIdx = 0     // next unconsumed segment of b
  var aOffset = 0  // character offset of segment aIdx inside a's text

  // Consumes the segments of both inputs up to (not including) `anchor`,
  // or up to the end when there is no anchor, and records the edits.
  def emitGap(anchor: Option[Segment]): Unit = {
    val gapStart = aOffset

    var removed = 0
    while (aIdx < a.length && !anchor.contains(a(aIdx))) {
      removed += a(aIdx).str.length
      aIdx += 1
    }
    aOffset += removed
    if (removed > 0) {
      edits += Deletion(gapStart, removed)
    }

    val inserted = new java.lang.StringBuilder()
    while (bIdx < b.length && !anchor.contains(b(bIdx))) {
      inserted.append(b(bIdx).str)
      bIdx += 1
    }
    if (inserted.length > 0) {
      edits += Addition(gapStart, inserted.toString)
    }
  }

  lcs(a, b).foreach { anchor =>
    emitGap(Some(anchor))
    // Step over the anchor itself in both inputs.
    aOffset += a(aIdx).str.length
    aIdx += 1
    bIdx += 1
  }

  // Everything after the last anchor.
  emitGap(None)

  edits.result().toSeq

}

/** Serializes a diff group: one line per edit ("+ pos text" for an Addition,
  * "- pos len" for a Deletion) followed by the terminator line "---".
  *
  * An empty edit list renders as just "---". Previously it rendered as an
  * empty line followed by "---", and parseLines cannot parse that blank line,
  * so two identical consecutive input lines broke decompression.
  */
def renderDiff(diffs: Seq[Diff]): String = {
  val rendered = diffs map {
    case Addition(pos, str) => s"+ $pos $str"
    case Deletion(pos, len) => s"- $pos $len"
  }
  (rendered :+ "---").mkString("\n")
}





/** Compresses a sequence of lines by storing each line as a diff against the
  * line before it.
  *
  * The first line is always emitted verbatim as "* <line>". Every following
  * line is emitted as a rendered diff group, unless the rendered diff is
  * longer than the line itself, in which case the line is emitted verbatim.
  *
  * The previous implementation paired lines with `sliding(2)`, which yields a
  * single one-element window for a one-line input and therefore failed with a
  * MatchError. This version handles empty, one-line and longer inputs alike.
  *
  * @param lines     input lines (without line terminators)
  * @param delimiter literal delimiter used to cut lines into diffable segments
  * @return the compressed records; a diff record may itself span several lines
  */
def compressLines(lines: Iterator[String], delimiter: String): Iterator[String] = {
  val segmented =
    lines.map(s => (s, splitWithDelimiter(s, delimiter).map(seg => Segment(seg.hashCode, seg))))

  if (!segmented.hasNext) {
    Iterator.empty
  } else {
    val (firstStr, firstSeg) = segmented.next()

    // Segments of the most recently seen line; updated as the iterator advances.
    var prevSeg = firstSeg

    Iterator("* "+firstStr) ++ segmented.map { case (str, seg) =>
      val diff = renderDiff(makeDiff(prevSeg, seg))
      prevSeg = seg

      if (diff.length > str.length) {
        "* "+str
      } else {
        diff
      }
    }
  }
}


/** Reverses compressLines: rebuilds the original lines from compressed records.
  *
  * An empty input now yields an empty result (it used to fail on `head` of an
  * empty stream, although compressing an empty file produces exactly that).
  *
  * @param lines lines of the compressed file
  * @return the reconstructed lines, in order
  * @throws AssertionError if the first record is not a verbatim ("* ") line,
  *                        or if applying a diff group yields an unexpected length
  */
def decompressLines(lines: Iterator[String]): Iterator[String] = {

  // Applies one diff group to the previous line.
  def applyDiffs(prevLine: String, diffs: Seq[Diff]): String = {
    val expectedLen = diffs.foldLeft(prevLine.length) {
      case (len, Deletion(_, n))   => len - n
      case (len, Addition(_, str)) => len + str.length
    }

    val sb = new java.lang.StringBuilder(expectedLen)
    var srcPos = 0 // next character of prevLine not yet copied or skipped

    // At equal positions the Addition must be applied BEFORE the Deletion:
    // an Addition resets srcPos to its own position, which would undo a
    // Deletion that had already skipped ahead.
    val ordered = diffs.sortBy {
      case Addition(pos, _) => (pos, 0)
      case Deletion(pos, _) => (pos, 1)
    }

    ordered foreach { d =>
      // Copy the untouched text in front of this edit.
      if (srcPos < d.pos) {
        sb.append(prevLine, srcPos, d.pos)
      }
      d match {
        case Deletion(pos, len) =>
          srcPos = pos + len
        case Addition(pos, str) =>
          sb.append(str)
          srcPos = pos
      }
    }

    // Copy whatever follows the last edit.
    if (srcPos < prevLine.length) {
      sb.append(prevLine, srcPos, prevLine.length)
    }

    assert(sb.length == expectedLen)

    sb.toString
  }

  val groups = parseLines(lines)

  groups.headOption match {
    case None =>
      Iterator.empty

    case Some(NewLine(first)) =>
      groups.tail.iterator.scanLeft(first) { (prevLine, diffGroup) =>
        diffGroup match {
          case NewLine(str) => str
          case Diffs(diffs) => applyDiffs(prevLine, diffs)
        }
      }

    case Some(other) =>
      throw new AssertionError(
        s"assertion failed: compressed input must start with a verbatim '* ' line, found $other")
  }
}



/** Parses the compressed text format into a lazy stream of records.
  *
  * A line starting with "* " is a verbatim line (NewLine). Any other run of
  * lines up to the terminator "---" is one diff group (Diffs) whose lines have
  * the form "+ pos text" or "- pos len".
  *
  * Compared with the previous version:
  *  - end of input is signalled with Option instead of a null sentinel;
  *  - blank lines inside a diff group are skipped (older output contains one
  *    for an empty diff) instead of causing a MatchError;
  *  - a diff group that is not closed by "---" (end of input, or directly
  *    followed by a "* " line) keeps all its edits and does not swallow the
  *    following line;
  *  - a malformed diff line raises an IllegalArgumentException naming the line.
  *
  * @throws IllegalArgumentException on a malformed diff line (this includes the
  *                                  NumberFormatException from a bad number)
  */
def parseLines(lines: Iterator[String]): Stream[DiffGroup] = {
  var stream = lines.toStream

  def parseDiff(line: String): Diff =
    line.split(" ", 3) match {
      case Array("+", posStr, str) =>
        Addition(posStr.toInt, str)
      case Array("-", posStr, lenStr) =>
        Deletion(posStr.toInt, lenStr.toInt)
      case _ =>
        throw new IllegalArgumentException(s"malformed diff line: '$line'")
    }

  // Consumes one record from `stream`; None once the input is exhausted.
  def chompDiffGroup(): Option[DiffGroup] = {
    val body = stream.takeWhile(l => l != "---" && !l.startsWith("* "))
    val bodyLen = body.length
    val afterBody = stream.drop(bodyLen)

    afterBody.headOption match {
      case None if bodyLen == 0 =>
        None

      case Some(line) if bodyLen == 0 && line != "---" =>
        // takeWhile stopped immediately on a line starting with "* ".
        stream = afterBody.tail
        Some(NewLine(line.drop(2)))

      case next =>
        // Consume the terminator only if it is really there.
        stream = if (next.contains("---")) afterBody.tail else afterBody
        Some(Diffs(body.filter(_.nonEmpty).map(parseDiff).toVector))
    }
  }

  // The head is parsed eagerly, every further record on demand.
  def groups(): Stream[DiffGroup] = chompDiffGroup() match {
    case Some(group) => group #:: groups()
    case None        => Stream.empty
  }

  groups()
}


// Command-line entry point of the script.
//   compress <delimiter> <input>   prints the compressed form of <input>
//   decompress <input>             prints the original lines back
// Anything else prints a usage message to stderr.

// Opens `path`, hands its lines to `body` and always closes the file afterwards
// (the source used to be left open). `body` must consume the lines before returning.
def withFileLines(path: String)(body: Iterator[String] => Unit): Unit = {
  val source = io.Source.fromFile(path)
  try body(source.getLines()) finally source.close()
}

args match {
  case Array("compress", delimiter, in) =>
    withFileLines(in) { lines =>
      compressLines(lines, delimiter) foreach println
    }

  case Array("decompress", in) =>
    withFileLines(in) { lines =>
      decompressLines(lines) foreach println
    }

  case _ => System.err.println("""
    |arguments:
    |  compress $delimiter $input
    |  decompress $input
    """.trim.stripMargin('|'))
}
píše k47 (@kaja47, k47)