Skip to content

Commit

Permalink
✨ feat: Add shift left, shift right, mod
Browse files Browse the repository at this point in the history
  • Loading branch information
caoccao committed Oct 28, 2024
1 parent d3ec183 commit 54ccc39
Show file tree
Hide file tree
Showing 4 changed files with 148 additions and 33 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,22 @@ class Test {
return a / b;
}

public mod_II_I(a: int, b: int): int {
return a % b;
}

public multiply_II_I(a: int, b: int): int {
return a * b;
}

public shift_left_II_I(a: int, b: int): int {
return a << b;
}

public shift_right_II_I(a: int, b: int): int {
return a >> b;
}

public subtract_II_I(a: int, b: int): int {
return a - b;
}
Expand All @@ -25,5 +37,8 @@ class Test {
console.log(new Test().add_II_I(1, 2));
console.log(new Test().add_IL_L(1, 2));
console.log(new Test().divide_II_I(3, 2));
console.log(new Test().mod_II_I(3, 2));
console.log(new Test().multiply_II_I(3, 2));
console.log(new Test().shift_left_II_I(3, 2));
console.log(new Test().shift_right_II_I(3, 1));
console.log(new Test().subtract_II_I(3, 2));
Original file line number Diff line number Diff line change
Expand Up @@ -65,9 +65,18 @@ public void manipulate(JavaFunctionContext functionContext, Swc4jAstBinExpr ast)
case Div:
stackManipulation = JavaClassCast.getDivision(upCastClass);
break;
case LShift:
stackManipulation = JavaClassCast.getShiftLeft(upCastClass);
break;
case Mod:
stackManipulation = JavaClassCast.getRemainder(upCastClass);
break;
case Mul:
stackManipulation = JavaClassCast.getMultiplication(upCastClass);
break;
case RShift:
stackManipulation = JavaClassCast.getShiftRight(upCastClass);
break;
case Sub:
stackManipulation = JavaClassCast.getSubtraction(upCastClass);
break;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,47 +30,100 @@ private JavaClassCast() {
}

public static Addition getAddition(Class<?> clazz) {
if (clazz == long.class) {
if (clazz == int.class) {
return Addition.INTEGER;
} else if (clazz == long.class) {
return Addition.LONG;
} else if (clazz == float.class) {
return Addition.FLOAT;
} else if (clazz == double.class) {
return Addition.DOUBLE;
}
return Addition.INTEGER;
throw new Ts2JavaException(
SimpleFreeMarkerFormat.format("Unsupported class ${class} in addition",
SimpleMap.of("class", clazz.getName())));
}

public static Division getDivision(Class<?> clazz) {
if (clazz == long.class) {
if (clazz == int.class) {
return Division.INTEGER;
} else if (clazz == long.class) {
return Division.LONG;
} else if (clazz == float.class) {
return Division.FLOAT;
} else if (clazz == double.class) {
return Division.DOUBLE;
}
return Division.INTEGER;
throw new Ts2JavaException(
SimpleFreeMarkerFormat.format("Unsupported class ${class} in division",
SimpleMap.of("class", clazz.getName())));
}

public static Multiplication getMultiplication(Class<?> clazz) {
if (clazz == long.class) {
if (clazz == int.class) {
return Multiplication.INTEGER;
} else if (clazz == long.class) {
return Multiplication.LONG;
} else if (clazz == float.class) {
return Multiplication.FLOAT;
} else if (clazz == double.class) {
return Multiplication.DOUBLE;
}
return Multiplication.INTEGER;
throw new Ts2JavaException(
SimpleFreeMarkerFormat.format("Unsupported class ${class} in multiplication",
SimpleMap.of("class", clazz.getName())));
}

public static Remainder getRemainder(Class<?> clazz) {
if (clazz == int.class) {
return Remainder.INTEGER;
} else if (clazz == long.class) {
return Remainder.LONG;
} else if (clazz == float.class) {
return Remainder.FLOAT;
} else if (clazz == double.class) {
return Remainder.DOUBLE;
}
throw new Ts2JavaException(
SimpleFreeMarkerFormat.format("Unsupported class ${class} in mod",
SimpleMap.of("class", clazz.getName())));
}

public static ShiftLeft getShiftLeft(Class<?> clazz) {
if (clazz == int.class) {
return ShiftLeft.INTEGER;
} else if (clazz == long.class) {
return ShiftLeft.LONG;
}
throw new Ts2JavaException(
SimpleFreeMarkerFormat.format("Unsupported class ${class} in left shift",
SimpleMap.of("class", clazz.getName())));
}

public static ShiftRight getShiftRight(Class<?> clazz) {
if (clazz == int.class) {
return ShiftRight.INTEGER;
} else if (clazz == long.class) {
return ShiftRight.LONG;
}
throw new Ts2JavaException(
SimpleFreeMarkerFormat.format("Unsupported class ${class} in right shift",
SimpleMap.of("class", clazz.getName())));
}

public static Subtraction getSubtraction(Class<?> clazz) {
if (clazz == long.class) {
if (clazz == int.class) {
return Subtraction.INTEGER;
} else if (clazz == long.class) {
return Subtraction.LONG;
} else if (clazz == float.class) {
return Subtraction.FLOAT;
} else if (clazz == double.class) {
return Subtraction.DOUBLE;
}
return Subtraction.INTEGER;
throw new Ts2JavaException(
SimpleFreeMarkerFormat.format("Unsupported class ${class} in subtraction",
SimpleMap.of("class", clazz.getName())));
}

public static Class<?> getUpCastClassForMathOp(Class<?>... classes) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,10 +25,10 @@

import static org.junit.jupiter.api.Assertions.*;

public class TestFourBasicOperations extends BaseTestTs2Java {
protected Class<?> clazz;
public class TestBasicOperations extends BaseTestTs2Java {
protected static Class<?> clazz = null;

public TestFourBasicOperations() {
public TestBasicOperations() {
super();
init();
}
Expand Down Expand Up @@ -61,24 +61,26 @@ public long add(int a, long b) {
}

protected void init() {
String tsCode = null;
try {
tsCode = getTsCode("test.four.basic.operations.ts");
} catch (IOException e) {
fail(e);
if (clazz == null) {
String tsCode = null;
try {
tsCode = getTsCode("test.basic.operations.ts");
} catch (IOException e) {
fail(e);
}
assertNotNull(tsCode);
Ts2Java ts2Java = new Ts2Java("com.test", tsCode);
try {
ts2Java.transpile();
} catch (Swc4jCoreException e) {
fail(e);
}
List<Class<?>> classes = ts2Java.getClasses();
assertEquals(1, classes.size());
clazz = classes.get(0);
assertEquals("Test", clazz.getSimpleName());
assertEquals("com.test.Test", clazz.getName());
}
assertNotNull(tsCode);
Ts2Java ts2Java = new Ts2Java("com.test", tsCode);
try {
ts2Java.transpile();
} catch (Swc4jCoreException e) {
fail(e);
}
List<Class<?>> classes = ts2Java.getClasses();
assertEquals(1, classes.size());
clazz = classes.get(0);
assertEquals("Test", clazz.getSimpleName());
assertEquals("com.test.Test", clazz.getName());
}

@Test
Expand All @@ -91,7 +93,7 @@ public void testAdd_II_I() throws Exception {
assertEquals(int.class, method.getParameters()[0].getType());
assertEquals(int.class, method.getParameters()[1].getType());
Object object = clazz.getConstructor().newInstance();
assertEquals(3, method.invoke(object, 1, 2));
assertEquals(1 + 2, method.invoke(object, 1, 2));
}

@Test
Expand All @@ -104,7 +106,7 @@ public void testAdd_IL_L() throws Exception {
assertEquals(int.class, method.getParameters()[0].getType());
assertEquals(long.class, method.getParameters()[1].getType());
Object object = clazz.getConstructor().newInstance();
assertEquals(3L, method.invoke(object, 1, 2L));
assertEquals(1 + 2L, method.invoke(object, 1, 2L));
}

@Test
Expand All @@ -116,7 +118,19 @@ public void testDivide_II_I() throws Exception {
assertEquals(int.class, method.getParameters()[0].getType());
assertEquals(int.class, method.getParameters()[1].getType());
Object object = clazz.getConstructor().newInstance();
assertEquals(1, method.invoke(object, 3, 2));
assertEquals(3 / 2, method.invoke(object, 3, 2));
}

@Test
public void testMod_II_I() throws Exception {
Method method = clazz.getMethod("mod_II_I", int.class, int.class);
assertNotNull(method);
assertEquals(int.class, method.getReturnType());
assertEquals(2, method.getParameterCount());
assertEquals(int.class, method.getParameters()[0].getType());
assertEquals(int.class, method.getParameters()[1].getType());
Object object = clazz.getConstructor().newInstance();
assertEquals(3 % 2, method.invoke(object, 3, 2));
}

@Test
Expand All @@ -128,7 +142,31 @@ public void testMultiply_II_I() throws Exception {
assertEquals(int.class, method.getParameters()[0].getType());
assertEquals(int.class, method.getParameters()[1].getType());
Object object = clazz.getConstructor().newInstance();
assertEquals(6, method.invoke(object, 3, 2));
assertEquals(3 * 2, method.invoke(object, 3, 2));
}

@Test
public void testShiftLeft_II_I() throws Exception {
Method method = clazz.getMethod("shift_left_II_I", int.class, int.class);
assertNotNull(method);
assertEquals(int.class, method.getReturnType());
assertEquals(2, method.getParameterCount());
assertEquals(int.class, method.getParameters()[0].getType());
assertEquals(int.class, method.getParameters()[1].getType());
Object object = clazz.getConstructor().newInstance();
assertEquals(3 << 2, method.invoke(object, 3, 2));
}

@Test
public void testShiftRight_II_I() throws Exception {
Method method = clazz.getMethod("shift_right_II_I", int.class, int.class);
assertNotNull(method);
assertEquals(int.class, method.getReturnType());
assertEquals(2, method.getParameterCount());
assertEquals(int.class, method.getParameters()[0].getType());
assertEquals(int.class, method.getParameters()[1].getType());
Object object = clazz.getConstructor().newInstance();
assertEquals(3 >> 1, method.invoke(object, 3, 1));
}

@Test
Expand All @@ -140,6 +178,6 @@ public void testSubtract_II_I() throws Exception {
assertEquals(int.class, method.getParameters()[0].getType());
assertEquals(int.class, method.getParameters()[1].getType());
Object object = clazz.getConstructor().newInstance();
assertEquals(1, method.invoke(object, 3, 2));
assertEquals(3 - 2, method.invoke(object, 3, 2));
}
}

0 comments on commit 54ccc39

Please sign in to comment.